• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Tests for unix_events.py."""
2
3import collections
4import errno
5import io
6import os
7import pathlib
8import signal
9import socket
10import stat
11import sys
12import tempfile
13import threading
14import unittest
15import warnings
16from unittest import mock
17
# asyncio.unix_events is POSIX-only, so skip this whole module on Windows
# before the asyncio imports that follow are attempted.
if sys.platform == 'win32':
    raise unittest.SkipTest('UNIX only')
20
21
22import asyncio
23from asyncio import log
24from asyncio import test_utils
25from asyncio import unix_events
26
27
# Short alias for the mock wildcard, used in exc_info=(...) log assertions.
MOCK_ANY = mock.ANY
29
30
def close_pipe_transport(transport):
    """Close the raw pipe held by *transport* and forget it.

    transport.close() is deliberately not used: the event loop and the
    selector in these tests are mocked.  The underlying pipe is closed
    directly instead.  Once ``_pipe`` is None, further calls do nothing.
    """
    pipe = transport._pipe
    if pipe is not None:
        pipe.close()
        transport._pipe = None
38
39
@unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopSignalTests(test_utils.TestCase):
    """Signal-handler bookkeeping of the UNIX SelectorEventLoop.

    Most tests replace the ``signal`` module seen by asyncio.unix_events
    with a mock, so no real process-level handler is installed.  The real
    ``signal.NSIG`` is copied onto the mock so that signal-number validation
    (exercised by test_check_signal) keeps working.

    The signal code under test lives in asyncio.unix_events and logs through
    that module's ``logger`` name.  That is the name patched below.  Patching
    asyncio.base_events.logger would leave the ``info`` assertions unable to
    observe anything.
    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.SelectorEventLoop()
        self.set_event_loop(self.loop)

    def test_check_signal(self):
        self.assertRaises(
            TypeError, self.loop._check_signal, '1')
        self.assertRaises(
            ValueError, self.loop._check_signal, signal.NSIG + 1)

    def test_handle_signal_no_handler(self):
        # Dispatching a signal that has no registered handler must not raise.
        self.loop._handle_signal(signal.NSIG + 1)

    def test_handle_signal_cancelled_handler(self):
        # A cancelled handler is unregistered when its signal is dispatched.
        h = asyncio.Handle(mock.Mock(), (),
                           loop=mock.Mock())
        h.cancel()
        self.loop._signal_handlers[signal.NSIG + 1] = h
        self.loop.remove_signal_handler = mock.Mock()
        self.loop._handle_signal(signal.NSIG + 1)
        self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_setup_error(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.set_wakeup_fd.side_effect = ValueError

        # A set_wakeup_fd() failure surfaces as RuntimeError.
        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_coroutine_error(self, m_signal):
        m_signal.NSIG = signal.NSIG

        @asyncio.coroutine
        def simple_coroutine():
            yield from []

        # callback must not be a coroutine function
        coro_func = simple_coroutine
        coro_obj = coro_func()
        self.addCleanup(coro_obj.close)
        for func in (coro_func, coro_obj):
            self.assertRaisesRegex(
                TypeError, 'coroutines cannot be used with add_signal_handler',
                self.loop.add_signal_handler,
                signal.SIGINT, func)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler(self, m_signal):
        m_signal.NSIG = signal.NSIG

        cb = lambda: True
        self.loop.add_signal_handler(signal.SIGHUP, cb)
        h = self.loop._signal_handlers.get(signal.SIGHUP)
        self.assertIsInstance(h, asyncio.Handle)
        self.assertEqual(h._callback, cb)

    @mock.patch('asyncio.unix_events.signal')
    def test_add_signal_handler_install_error(self, m_signal):
        m_signal.NSIG = signal.NSIG

        # Installing the wakeup fd works; only resetting it (fd == -1) fails.
        def set_wakeup_fd(fd):
            if fd == -1:
                raise ValueError()
        m_signal.set_wakeup_fd = set_wakeup_fd

        class Err(OSError):
            errno = errno.EFAULT
        m_signal.signal.side_effect = Err

        # An OSError whose errno is not EINVAL propagates unchanged.
        self.assertRaises(
            Err,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)

    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.unix_events.logger')
    def test_add_signal_handler_install_error2(self, m_logging, m_signal):
        m_signal.NSIG = signal.NSIG

        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err

        # Another handler is still registered, so the wakeup fd must stay
        # installed: set_wakeup_fd is called only once (to install it).
        self.loop._signal_handlers[signal.SIGHUP] = lambda: True
        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)
        self.assertFalse(m_logging.info.called)
        self.assertEqual(1, m_signal.set_wakeup_fd.call_count)

    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.unix_events.logger')
    def test_add_signal_handler_install_error3(self, m_logging, m_signal):
        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err
        m_signal.NSIG = signal.NSIG

        # No other handler remains, so the wakeup fd is installed and then
        # reset: two set_wakeup_fd calls, and the reset succeeds silently.
        self.assertRaises(
            RuntimeError,
            self.loop.add_signal_handler,
            signal.SIGINT, lambda: True)
        self.assertFalse(m_logging.info.called)
        self.assertEqual(2, m_signal.set_wakeup_fd.call_count)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler(self, m_signal):
        m_signal.NSIG = signal.NSIG

        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        self.assertTrue(
            self.loop.remove_signal_handler(signal.SIGHUP))
        self.assertTrue(m_signal.set_wakeup_fd.called)
        self.assertTrue(m_signal.signal.called)
        self.assertEqual(
            (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0])

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_2(self, m_signal):
        m_signal.NSIG = signal.NSIG
        m_signal.SIGINT = signal.SIGINT

        self.loop.add_signal_handler(signal.SIGINT, lambda: True)
        self.loop._signal_handlers[signal.SIGHUP] = object()
        m_signal.set_wakeup_fd.reset_mock()

        # SIGINT is restored to default_int_handler rather than SIG_DFL, and
        # the wakeup fd is left alone while SIGHUP is still registered.
        self.assertTrue(
            self.loop.remove_signal_handler(signal.SIGINT))
        self.assertFalse(m_signal.set_wakeup_fd.called)
        self.assertTrue(m_signal.signal.called)
        self.assertEqual(
            (signal.SIGINT, m_signal.default_int_handler),
            m_signal.signal.call_args[0])

    @mock.patch('asyncio.unix_events.signal')
    @mock.patch('asyncio.unix_events.logger')
    def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal):
        m_signal.NSIG = signal.NSIG
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        m_signal.set_wakeup_fd.side_effect = ValueError

        # A failing set_wakeup_fd(-1) during cleanup is logged, not raised.
        # Check ``.called``: a bare Mock attribute is always truthy.
        self.loop.remove_signal_handler(signal.SIGHUP)
        self.assertTrue(m_logging.info.called)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_error(self, m_signal):
        m_signal.NSIG = signal.NSIG
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        m_signal.signal.side_effect = OSError

        self.assertRaises(
            OSError, self.loop.remove_signal_handler, signal.SIGHUP)

    @mock.patch('asyncio.unix_events.signal')
    def test_remove_signal_handler_error2(self, m_signal):
        m_signal.NSIG = signal.NSIG
        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)

        class Err(OSError):
            errno = errno.EINVAL
        m_signal.signal.side_effect = Err

        # EINVAL from signal.signal() is reported as RuntimeError.
        self.assertRaises(
            RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP)

    @mock.patch('asyncio.unix_events.signal')
    def test_close(self, m_signal):
        m_signal.NSIG = signal.NSIG

        self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
        self.loop.add_signal_handler(signal.SIGCHLD, lambda: True)

        self.assertEqual(len(self.loop._signal_handlers), 2)

        m_signal.set_wakeup_fd.reset_mock()

        self.loop.close()

        # Closing drops every handler and resets the wakeup fd exactly once.
        self.assertEqual(len(self.loop._signal_handlers), 0)
        m_signal.set_wakeup_fd.assert_called_once_with(-1)
232
233
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
                     'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
    """create_unix_server()/create_unix_connection() on a SelectorEventLoop.

    Covers argument validation (path vs. sock, ssl/server_hostname), the
    kinds of socket accepted, and socket cleanup when bind() fails.
    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.SelectorEventLoop()
        self.set_event_loop(self.loop)

    def test_create_unix_server_existing_path_sock(self):
        # A stale socket file left at *path* does not prevent serving there.
        with test_utils.unix_socket_path() as path:
            sock = socket.socket(socket.AF_UNIX)
            sock.bind(path)
            sock.listen(1)
            sock.close()

            coro = self.loop.create_unix_server(lambda: None, path)
            srv = self.loop.run_until_complete(coro)
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    @unittest.skipUnless(hasattr(os, 'fspath'), 'no os.fspath')
    def test_create_unix_server_pathlib(self):
        with test_utils.unix_socket_path() as path:
            path = pathlib.Path(path)
            srv_coro = self.loop.create_unix_server(lambda: None, path)
            srv = self.loop.run_until_complete(srv_coro)
            srv.close()
            self.loop.run_until_complete(srv.wait_closed())

    def test_create_unix_server_existing_path_nonsock(self):
        # A regular file at *path* is reported as "address in use".
        with tempfile.NamedTemporaryFile() as file:
            coro = self.loop.create_unix_server(lambda: None, file.name)
            with self.assertRaisesRegex(OSError,
                                        'Address.*is already in use'):
                self.loop.run_until_complete(coro)

    def test_create_unix_server_ssl_bool(self):
        coro = self.loop.create_unix_server(lambda: None, path='spam',
                                            ssl=True)
        with self.assertRaisesRegex(TypeError,
                                    'ssl argument must be an SSLContext'):
            self.loop.run_until_complete(coro)

    def test_create_unix_server_nopath_nosock(self):
        coro = self.loop.create_unix_server(lambda: None, path=None)
        with self.assertRaisesRegex(ValueError,
                                    'path was not specified, and no sock'):
            self.loop.run_until_complete(coro)

    def test_create_unix_server_path_inetsock(self):
        sock = socket.socket()
        with sock:
            coro = self.loop.create_unix_server(lambda: None, path=None,
                                                sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    def test_create_unix_server_path_dgram(self):
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
        with sock:
            coro = self.loop.create_unix_server(lambda: None, path=None,
                                                sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
                         'no socket.SOCK_NONBLOCK (linux only)')
    def test_create_unix_server_path_stream_bittype(self):
        # A stream socket whose type carries extra flag bits is still
        # accepted as a UNIX stream socket.
        sock = socket.socket(
            socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
        # Only a unique, currently-unused file name is wanted here.
        with tempfile.NamedTemporaryFile() as file:
            fn = file.name
        try:
            with sock:
                sock.bind(fn)
                coro = self.loop.create_unix_server(lambda: None, path=None,
                                                    sock=sock)
                srv = self.loop.run_until_complete(coro)
                srv.close()
                self.loop.run_until_complete(srv.wait_closed())
        finally:
            os.unlink(fn)

    def test_create_unix_connection_path_inetsock(self):
        sock = socket.socket()
        with sock:
            coro = self.loop.create_unix_connection(lambda: None, path=None,
                                                    sock=sock)
            with self.assertRaisesRegex(ValueError,
                                        'A UNIX Domain Stream.*was expected'):
                self.loop.run_until_complete(coro)

    @mock.patch('asyncio.unix_events.socket')
    def test_create_unix_server_bind_error(self, m_socket):
        # Ensure that the socket is closed on any bind error
        sock = mock.Mock()
        m_socket.socket.return_value = sock

        sock.bind.side_effect = OSError
        coro = self.loop.create_unix_server(lambda: None, path="/test")
        with self.assertRaises(OSError):
            self.loop.run_until_complete(coro)
        self.assertTrue(sock.close.called)

        # Forget the close() recorded above.  Otherwise the check below
        # would pass even if a non-OSError bind failure leaked the socket.
        sock.close.reset_mock()
        sock.bind.side_effect = MemoryError
        coro = self.loop.create_unix_server(lambda: None, path="/test")
        with self.assertRaises(MemoryError):
            self.loop.run_until_complete(coro)
        self.assertTrue(sock.close.called)

    def test_create_unix_connection_path_sock(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, sock=object())
        with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_nopath_nosock(self):
        coro = self.loop.create_unix_connection(
            lambda: None, None)
        with self.assertRaisesRegex(ValueError,
                                    'no path and sock were specified'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_nossl_serverhost(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, server_hostname='spam')
        with self.assertRaisesRegex(ValueError,
                                    'server_hostname is only meaningful'):
            self.loop.run_until_complete(coro)

    def test_create_unix_connection_ssl_noserverhost(self):
        coro = self.loop.create_unix_connection(
            lambda: None, os.devnull, ssl=True)

        with self.assertRaisesRegex(
                ValueError, 'you have to pass server_hostname when using ssl'):
            self.loop.run_until_complete(coro)
375
376
class UnixReadPipeTransportTests(test_utils.TestCase):
    """Unit tests for _UnixReadPipeTransport on a fully mocked pipe.

    The pipe is a ``spec_set`` mock of io.RawIOBase whose fileno() is 5.
    os.fstat() is patched to report a FIFO and _set_nonblocking() is patched
    out, so no real file descriptor is touched.
    """

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
        self.pipe = mock.Mock(spec_set=io.RawIOBase)
        self.pipe.fileno.return_value = 5

        self._start_patch('asyncio.unix_events._set_nonblocking')
        fstat = self._start_patch('os.fstat')
        fstat.return_value = mock.Mock(st_mode=stat.S_IFIFO)

    def _start_patch(self, target):
        """Start ``mock.patch(target)``, undo it at cleanup, return the mock."""
        patcher = mock.patch(target)
        replacement = patcher.start()
        self.addCleanup(patcher.stop)
        return replacement

    def read_pipe_transport(self, waiter=None):
        """Build a read transport over self.pipe; its pipe is closed at cleanup."""
        transport = unix_events._UnixReadPipeTransport(
            self.loop, self.pipe, self.protocol, waiter=waiter)
        self.addCleanup(close_pipe_transport, transport)
        return transport

    def test_ctor(self):
        waiter = asyncio.Future(loop=self.loop)
        transport = self.read_pipe_transport(waiter=waiter)
        self.loop.run_until_complete(waiter)

        self.protocol.connection_made.assert_called_with(transport)
        self.loop.assert_reader(5, transport._read_ready)
        self.assertIsNone(waiter.result())

    @mock.patch('os.read')
    def test__read_ready(self, m_read):
        transport = self.read_pipe_transport()
        m_read.return_value = b'data'
        transport._read_ready()

        m_read.assert_called_with(5, transport.max_size)
        self.protocol.data_received.assert_called_with(b'data')

    @mock.patch('os.read')
    def test__read_ready_eof(self, m_read):
        transport = self.read_pipe_transport()
        m_read.return_value = b''  # an empty read is end-of-file
        transport._read_ready()

        m_read.assert_called_with(5, transport.max_size)
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.eof_received.assert_called_with()
        self.protocol.connection_lost.assert_called_with(None)

    @mock.patch('os.read')
    def test__read_ready_blocked(self, m_read):
        transport = self.read_pipe_transport()
        m_read.side_effect = BlockingIOError
        transport._read_ready()

        m_read.assert_called_with(5, transport.max_size)
        test_utils.run_briefly(self.loop)
        self.assertFalse(self.protocol.data_received.called)

    @mock.patch('asyncio.log.logger.error')
    @mock.patch('os.read')
    def test__read_ready_error(self, m_read, m_logexc):
        transport = self.read_pipe_transport()
        m_read.side_effect = err = OSError()
        transport._close = mock.Mock()
        transport._read_ready()

        m_read.assert_called_with(5, transport.max_size)
        transport._close.assert_called_with(err)
        m_logexc.assert_called_with(
            test_utils.MockPattern(
                'Fatal read error on pipe transport'
                '\nprotocol:.*\ntransport:.*'),
            exc_info=(OSError, MOCK_ANY, MOCK_ANY))

    @mock.patch('os.read')
    def test_pause_reading(self, m_read):
        transport = self.read_pipe_transport()
        self.loop.add_reader(5, mock.Mock())
        transport.pause_reading()
        self.assertFalse(self.loop.readers)

    @mock.patch('os.read')
    def test_resume_reading(self, m_read):
        transport = self.read_pipe_transport()
        transport.resume_reading()
        self.loop.assert_reader(5, transport._read_ready)

    @mock.patch('os.read')
    def test_close(self, m_read):
        transport = self.read_pipe_transport()
        transport._close = mock.Mock()
        transport.close()
        transport._close.assert_called_with(None)

    @mock.patch('os.read')
    def test_close_already_closing(self, m_read):
        transport = self.read_pipe_transport()
        transport._closing = True
        transport._close = mock.Mock()
        transport.close()
        self.assertFalse(transport._close.called)

    @mock.patch('os.read')
    def test__close(self, m_read):
        transport = self.read_pipe_transport()
        sentinel = object()
        transport._close(sentinel)
        self.assertTrue(transport.is_closing())
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(sentinel)

    def _check_call_connection_lost(self, err):
        """Shared body of the two _call_connection_lost tests."""
        transport = self.read_pipe_transport()
        self.assertIsNotNone(transport._protocol)
        self.assertIsNotNone(transport._loop)

        transport._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        # The transport drops its protocol and loop references afterwards.
        self.assertIsNone(transport._protocol)
        self.assertIsNone(transport._loop)

    def test__call_connection_lost(self):
        self._check_call_connection_lost(None)

    def test__call_connection_lost_with_err(self):
        self._check_call_connection_lost(OSError())
525
526
class UnixWritePipeTransportTests(test_utils.TestCase):
    """Unit tests for _UnixWritePipeTransport on a fully mocked pipe.

    The pipe is a ``spec_set`` mock of io.RawIOBase whose fileno() is 5,
    and _set_nonblocking() is patched out.  os.fstat() reports a socket
    (S_IFSOCK); with that mode, test_ctor expects a reader to be registered
    on the fd as well as the protocol being connected.  The transport keeps
    unsent data in a ``bytearray`` (see test_write), so tests that seed
    ``_buffer`` use that type too.
    """

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
        self.pipe = mock.Mock(spec_set=io.RawIOBase)
        self.pipe.fileno.return_value = 5

        blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking')
        blocking_patcher.start()
        self.addCleanup(blocking_patcher.stop)

        fstat_patcher = mock.patch('os.fstat')
        m_fstat = fstat_patcher.start()
        st = mock.Mock()
        st.st_mode = stat.S_IFSOCK
        m_fstat.return_value = st
        self.addCleanup(fstat_patcher.stop)

    def write_pipe_transport(self, waiter=None):
        """Build a write transport over self.pipe; its pipe is closed at cleanup."""
        transport = unix_events._UnixWritePipeTransport(self.loop, self.pipe,
                                                        self.protocol,
                                                        waiter=waiter)
        self.addCleanup(close_pipe_transport, transport)
        return transport

    def test_ctor(self):
        waiter = asyncio.Future(loop=self.loop)
        tr = self.write_pipe_transport(waiter=waiter)
        self.loop.run_until_complete(waiter)

        self.protocol.connection_made.assert_called_with(tr)
        self.loop.assert_reader(5, tr._read_ready)
        self.assertIsNone(waiter.result())

    def test_can_write_eof(self):
        tr = self.write_pipe_transport()
        self.assertTrue(tr.can_write_eof())

    @mock.patch('os.write')
    def test_write(self, m_write):
        tr = self.write_pipe_transport()
        m_write.return_value = 4
        tr.write(b'data')
        m_write.assert_called_with(5, b'data')
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)

    @mock.patch('os.write')
    def test_write_no_data(self, m_write):
        tr = self.write_pipe_transport()
        tr.write(b'')
        self.assertFalse(m_write.called)
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(b''), tr._buffer)

    @mock.patch('os.write')
    def test_write_partial(self, m_write):
        # Only 2 of 4 bytes are written; the rest is buffered and a writer
        # is registered to flush it later.
        tr = self.write_pipe_transport()
        m_write.return_value = 2
        tr.write(b'data')
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'ta'), tr._buffer)

    @mock.patch('os.write')
    def test_write_buffer(self, m_write):
        # With data already pending, new data is appended, not written.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'previous')
        tr.write(b'data')
        self.assertFalse(m_write.called)
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'previousdata'), tr._buffer)

    @mock.patch('os.write')
    def test_write_again(self, m_write):
        tr = self.write_pipe_transport()
        m_write.side_effect = BlockingIOError()
        tr.write(b'data')
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('asyncio.unix_events.logger')
    @mock.patch('os.write')
    def test_write_err(self, m_write, m_log):
        tr = self.write_pipe_transport()
        err = OSError()
        m_write.side_effect = err
        tr._fatal_error = mock.Mock()
        tr.write(b'data')
        m_write.assert_called_with(5, b'data')
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)
        tr._fatal_error.assert_called_with(
            err,
            'Fatal write error on pipe transport')
        self.assertEqual(1, tr._conn_lost)

        # Every later write only bumps _conn_lost; after enough of them a
        # warning is logged.
        tr.write(b'data')
        self.assertEqual(2, tr._conn_lost)
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        tr.write(b'data')
        # This is a bit overspecified. :-(
        m_log.warning.assert_called_with(
            'pipe closed by peer or os.write(pipe, data) raised exception.')
        tr.close()

    @mock.patch('os.write')
    def test_write_close(self, m_write):
        tr = self.write_pipe_transport()
        tr._read_ready()  # pipe was closed by peer

        tr.write(b'data')
        self.assertEqual(tr._conn_lost, 1)
        tr.write(b'data')
        self.assertEqual(tr._conn_lost, 2)

    def test__read_ready(self):
        # Readability on a write pipe is treated as the peer closing it.
        tr = self.write_pipe_transport()
        tr._read_ready()
        self.assertFalse(self.loop.readers)
        self.assertFalse(self.loop.writers)
        self.assertTrue(tr.is_closing())
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    @mock.patch('os.write')
    def test__write_ready(self, m_write):
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 4
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_partial(self, m_write):
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 3
        tr._write_ready()
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'a'), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_again(self, m_write):
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.side_effect = BlockingIOError()
        tr._write_ready()
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('os.write')
    def test__write_ready_empty(self, m_write):
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.return_value = 0
        tr._write_ready()
        m_write.assert_called_with(5, bytearray(b'data'))
        self.loop.assert_writer(5, tr._write_ready)
        self.assertEqual(bytearray(b'data'), tr._buffer)

    @mock.patch('asyncio.log.logger.error')
    @mock.patch('os.write')
    def test__write_ready_err(self, m_write, m_logexc):
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._buffer = bytearray(b'data')
        m_write.side_effect = err = OSError()
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertFalse(self.loop.readers)
        self.assertEqual(bytearray(), tr._buffer)
        self.assertTrue(tr.is_closing())
        m_logexc.assert_called_with(
            test_utils.MockPattern(
                'Fatal write error on pipe transport'
                '\nprotocol:.*\ntransport:.*'),
            exc_info=(OSError, MOCK_ANY, MOCK_ANY))
        self.assertEqual(1, tr._conn_lost)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(err)

    @mock.patch('os.write')
    def test__write_ready_closing(self, m_write):
        # Flushing the last buffered bytes of a closing transport finishes
        # the close synchronously.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        tr._closing = True
        tr._buffer = bytearray(b'data')
        m_write.return_value = 4
        tr._write_ready()
        self.assertFalse(self.loop.writers)
        self.assertFalse(self.loop.readers)
        self.assertEqual(bytearray(), tr._buffer)
        self.protocol.connection_lost.assert_called_with(None)
        self.pipe.close.assert_called_with()

    @mock.patch('os.write')
    def test_abort(self, m_write):
        # abort() discards pending data without writing it.  The buffer is
        # seeded as a bytearray, the type the transport itself uses.
        tr = self.write_pipe_transport()
        self.loop.add_writer(5, tr._write_ready)
        self.loop.add_reader(5, tr._read_ready)
        tr._buffer = bytearray(b'data')
        tr.abort()
        self.assertFalse(m_write.called)
        self.assertFalse(self.loop.readers)
        self.assertFalse(self.loop.writers)
        self.assertEqual(bytearray(), tr._buffer)
        self.assertTrue(tr.is_closing())
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test__call_connection_lost(self):
        tr = self.write_pipe_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)

        err = None
        tr._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)

    def test__call_connection_lost_with_err(self):
        tr = self.write_pipe_transport()
        self.assertIsNotNone(tr._protocol)
        self.assertIsNotNone(tr._loop)

        err = OSError()
        tr._call_connection_lost(err)
        self.protocol.connection_lost.assert_called_with(err)
        self.pipe.close.assert_called_with()

        self.assertIsNone(tr._protocol)
        self.assertIsNone(tr._loop)

    def test_close(self):
        tr = self.write_pipe_transport()
        tr.write_eof = mock.Mock()
        tr.close()
        tr.write_eof.assert_called_with()

        # closing the transport twice must not fail
        tr.close()

    def test_close_closing(self):
        tr = self.write_pipe_transport()
        tr.write_eof = mock.Mock()
        tr._closing = True
        tr.close()
        self.assertFalse(tr.write_eof.called)

    def test_write_eof(self):
        tr = self.write_pipe_transport()
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.loop.readers)
        test_utils.run_briefly(self.loop)
        self.protocol.connection_lost.assert_called_with(None)

    def test_write_eof_pending(self):
        # With data still buffered, write_eof() marks the transport as
        # closing but does not report connection_lost yet.
        tr = self.write_pipe_transport()
        tr._buffer = bytearray(b'data')
        tr.write_eof()
        self.assertTrue(tr.is_closing())
        self.assertFalse(self.protocol.connection_lost.called)
799    def test_write_eof_pending(self):
800        tr = self.write_pipe_transport()
801        tr._buffer = [b'data']
802        tr.write_eof()
803        self.assertTrue(tr.is_closing())
804        self.assertFalse(self.protocol.connection_lost.called)
805
806
class AbstractChildWatcherTests(unittest.TestCase):
    """AbstractChildWatcher is a pure interface: every method must raise."""

    def test_not_implemented(self):
        f = mock.Mock()
        watcher = asyncio.AbstractChildWatcher()
        # (bound method, positional args) pairs, checked in this order.
        expectations = [
            (watcher.add_child_handler, (f, f)),
            (watcher.remove_child_handler, (f,)),
            (watcher.attach_loop, (f,)),
            (watcher.close, ()),
            (watcher.__enter__, ()),
            (watcher.__exit__, (f, f, f)),
        ]
        for method, args in expectations:
            self.assertRaises(NotImplementedError, method, *args)
824
825
class BaseChildWatcherTests(unittest.TestCase):
    """BaseChildWatcher leaves _do_waitpid() for subclasses to implement."""

    def test_not_implemented(self):
        watcher = unix_events.BaseChildWatcher()
        with self.assertRaises(NotImplementedError):
            watcher._do_waitpid(mock.Mock())
833
834
# Bundle of the patched os.waitpid / os.W* mocks passed to every test
# decorated with ChildWatcherTestsMixin.waitpid_mocks.
WaitPidMocks = collections.namedtuple(
    "WaitPidMocks",
    "waitpid WIFEXITED WIFSIGNALED WEXITSTATUS WTERMSIG",
)
842
843
class ChildWatcherTestsMixin:
    """Shared test suite for child watcher implementations.

    Concrete subclasses combine this mixin with a test case class that
    provides ``new_test_loop()``, and define ``create_watcher()``.
    Tests decorated with ``waitpid_mocks`` run with ``os.waitpid`` and the
    ``os.W*`` status helpers replaced by the fakes below. The fakes are
    driven by ``self.zombies`` and ``self.running``, so no real child
    process is involved.
    """

    # Patcher silencing log.logger.warning; reused as a context manager
    # around _sig_chld() calls that are allowed to emit warnings.
    ignore_warnings = mock.patch.object(log.logger, "warning")

    def setUp(self):
        super().setUp()
        self.loop = self.new_test_loop()
        # While True, the fake waitpid() answers (0, 0) ("still running")
        # instead of raising ChildProcessError when no zombie matches.
        self.running = False
        # pid -> fake wait status, consumed by the fake waitpid().
        self.zombies = {}

        # Capture the SIGCHLD registration done by attach_loop() so that
        # test_create_watcher can check it.
        with mock.patch.object(
                self.loop, "add_signal_handler") as self.m_add_signal_handler:
            self.watcher = self.create_watcher()
            self.watcher.attach_loop(self.loop)

    def waitpid(self, pid, flags):
        """Fake os.waitpid() backed by self.zombies.

        A negative pid reaps whichever zombie dict.popitem() yields, and a
        positive pid reaps that specific child. When nothing matches, the
        fake returns (0, 0) if self.running is set, otherwise it raises
        ChildProcessError.
        """
        if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1:
            # SafeChildWatcher may only wait for specific, positive pids.
            self.assertGreater(pid, 0)
        try:
            if pid < 0:
                return self.zombies.popitem()
            else:
                return pid, self.zombies.pop(pid)
        except KeyError:
            pass
        if self.running:
            return 0, 0
        else:
            raise ChildProcessError()

    # Fake status encoding: status = returncode + 32768.
    #   returncode >= 0        -> status >= 32768         ("exited")
    #   returncode in -67..-1  -> 32700 < status < 32768  ("signaled")
    # Any other status is neither exited nor signaled.

    def add_zombie(self, pid, returncode):
        """Queue a terminated child *pid* that exited with *returncode*."""
        self.zombies[pid] = returncode + 32768

    def WIFEXITED(self, status):
        return status >= 32768

    def WIFSIGNALED(self, status):
        return 32700 < status < 32768

    def WEXITSTATUS(self, status):
        self.assertTrue(self.WIFEXITED(status))
        return status - 32768

    def WTERMSIG(self, status):
        self.assertTrue(self.WIFSIGNALED(status))
        return 32768 - status

    def _assert_no_status_calls(self, m):
        """Assert that none of the patched os.W* status helpers was called."""
        self.assertFalse(m.WIFEXITED.called)
        self.assertFalse(m.WIFSIGNALED.called)
        self.assertFalse(m.WEXITSTATUS.called)
        self.assertFalse(m.WTERMSIG.called)

    def test_create_watcher(self):
        # attach_loop() in setUp() must install _sig_chld for SIGCHLD.
        self.m_add_signal_handler.assert_called_once_with(
            signal.SIGCHLD, self.watcher._sig_chld)

    def waitpid_mocks(func):
        """Decorator: run *func* with os.waitpid and os.W* helpers patched.

        Each patched name is a Mock wrapping the matching fake method of
        the test case, so calls are both recorded and answered by the
        fakes. The mocks are passed to *func* as a WaitPidMocks tuple.
        """
        import functools

        # functools.wraps keeps each test's __name__/__qualname__/__doc__
        # instead of every decorated test becoming "wrapped_func".
        @functools.wraps(func)
        def wrapped_func(self):
            def patch(target, wrapper):
                return mock.patch(target, wraps=wrapper,
                                  new_callable=mock.Mock)

            with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \
                 patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \
                 patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \
                 patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \
                 patch('os.waitpid', self.waitpid) as m_waitpid:
                func(self, WaitPidMocks(m_waitpid,
                                        m_WIFEXITED, m_WIFSIGNALED,
                                        m_WEXITSTATUS, m_WTERMSIG,
                                        ))
        return wrapped_func

    @waitpid_mocks
    def test_sigchld(self, m):
        # register a child
        callback = mock.Mock()

        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(42, callback, 9, 10, 14)

        self.assertFalse(callback.called)
        self._assert_no_status_calls(m)

        # child is running
        self.watcher._sig_chld()

        self.assertFalse(callback.called)
        self._assert_no_status_calls(m)

        # child terminates (returncode 12)
        self.running = False
        self.add_zombie(42, 12)
        self.watcher._sig_chld()

        self.assertTrue(m.WIFEXITED.called)
        self.assertTrue(m.WEXITSTATUS.called)
        self.assertFalse(m.WTERMSIG.called)
        callback.assert_called_once_with(42, 12, 9, 10, 14)

        m.WIFSIGNALED.reset_mock()
        m.WIFEXITED.reset_mock()
        m.WEXITSTATUS.reset_mock()
        callback.reset_mock()

        # ensure that the child is effectively reaped
        self.add_zombie(42, 13)
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback.called)
        self.assertFalse(m.WTERMSIG.called)

        m.WIFSIGNALED.reset_mock()
        m.WIFEXITED.reset_mock()
        m.WEXITSTATUS.reset_mock()

        # sigchld called again
        self.zombies.clear()
        self.watcher._sig_chld()

        self.assertFalse(callback.called)
        self._assert_no_status_calls(m)

    @waitpid_mocks
    def test_sigchld_two_children(self, m):
        callback1 = mock.Mock()
        callback2 = mock.Mock()

        # register child 1
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(43, callback1, 7, 8)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # register child 2
        with self.watcher:
            self.watcher.add_child_handler(44, callback2, 147, 18)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # children are running
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # child 1 terminates (signal 3)
        self.add_zombie(43, -3)
        self.watcher._sig_chld()

        callback1.assert_called_once_with(43, -3, 7, 8)
        self.assertFalse(callback2.called)
        self.assertTrue(m.WIFSIGNALED.called)
        self.assertFalse(m.WEXITSTATUS.called)
        self.assertTrue(m.WTERMSIG.called)

        m.WIFSIGNALED.reset_mock()
        m.WIFEXITED.reset_mock()
        m.WTERMSIG.reset_mock()
        callback1.reset_mock()

        # child 2 still running
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # child 2 terminates (code 108)
        self.add_zombie(44, 108)
        self.running = False
        self.watcher._sig_chld()

        callback2.assert_called_once_with(44, 108, 147, 18)
        self.assertFalse(callback1.called)
        self.assertTrue(m.WIFEXITED.called)
        self.assertTrue(m.WEXITSTATUS.called)
        self.assertFalse(m.WTERMSIG.called)

        m.WIFSIGNALED.reset_mock()
        m.WIFEXITED.reset_mock()
        m.WEXITSTATUS.reset_mock()
        callback2.reset_mock()

        # ensure that the children are effectively reaped
        self.add_zombie(43, 14)
        self.add_zombie(44, 15)
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self.assertFalse(m.WTERMSIG.called)

        m.WIFSIGNALED.reset_mock()
        m.WIFEXITED.reset_mock()
        m.WEXITSTATUS.reset_mock()

        # sigchld called again
        self.zombies.clear()
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

    @waitpid_mocks
    def test_sigchld_two_children_terminating_together(self, m):
        callback1 = mock.Mock()
        callback2 = mock.Mock()

        # register child 1
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(45, callback1, 17, 8)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # register child 2
        with self.watcher:
            self.watcher.add_child_handler(46, callback2, 1147, 18)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # children are running
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # child 1 terminates (code 78)
        # child 2 terminates (signal 5)
        self.add_zombie(45, 78)
        self.add_zombie(46, -5)
        self.running = False
        self.watcher._sig_chld()

        callback1.assert_called_once_with(45, 78, 17, 8)
        callback2.assert_called_once_with(46, -5, 1147, 18)
        self.assertTrue(m.WIFSIGNALED.called)
        self.assertTrue(m.WIFEXITED.called)
        self.assertTrue(m.WEXITSTATUS.called)
        self.assertTrue(m.WTERMSIG.called)

        m.WIFSIGNALED.reset_mock()
        m.WIFEXITED.reset_mock()
        m.WTERMSIG.reset_mock()
        m.WEXITSTATUS.reset_mock()
        callback1.reset_mock()
        callback2.reset_mock()

        # ensure that the children are effectively reaped
        self.add_zombie(45, 14)
        self.add_zombie(46, 15)
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self.assertFalse(m.WTERMSIG.called)

    @waitpid_mocks
    def test_sigchld_race_condition(self, m):
        # register a child
        callback = mock.Mock()

        with self.watcher:
            # child terminates before being registered
            self.add_zombie(50, 4)
            self.watcher._sig_chld()

            self.watcher.add_child_handler(50, callback, 1, 12)

        callback.assert_called_once_with(50, 4, 1, 12)
        callback.reset_mock()

        # ensure that the child is effectively reaped
        self.add_zombie(50, -1)
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback.called)

    @waitpid_mocks
    def test_sigchld_replace_handler(self, m):
        callback1 = mock.Mock()
        callback2 = mock.Mock()

        # register a child
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(51, callback1, 19)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # register the same child again
        with self.watcher:
            self.watcher.add_child_handler(51, callback2, 21)

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self._assert_no_status_calls(m)

        # child terminates (signal 8)
        self.running = False
        self.add_zombie(51, -8)
        self.watcher._sig_chld()

        callback2.assert_called_once_with(51, -8, 21)
        self.assertFalse(callback1.called)
        self.assertTrue(m.WIFSIGNALED.called)
        self.assertFalse(m.WEXITSTATUS.called)
        self.assertTrue(m.WTERMSIG.called)

        m.WIFSIGNALED.reset_mock()
        m.WIFEXITED.reset_mock()
        m.WTERMSIG.reset_mock()
        callback2.reset_mock()

        # ensure that the child is effectively reaped
        self.add_zombie(51, 13)
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self.assertFalse(m.WTERMSIG.called)

    @waitpid_mocks
    def test_sigchld_remove_handler(self, m):
        callback = mock.Mock()

        # register a child
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(52, callback, 1984)

        self.assertFalse(callback.called)
        self._assert_no_status_calls(m)

        # unregister the child
        self.watcher.remove_child_handler(52)

        self.assertFalse(callback.called)
        self._assert_no_status_calls(m)

        # child terminates (code 99)
        self.running = False
        self.add_zombie(52, 99)
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback.called)

    @waitpid_mocks
    def test_sigchld_unknown_status(self, m):
        callback = mock.Mock()

        # register a child
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(53, callback, -19)

        self.assertFalse(callback.called)
        self._assert_no_status_calls(m)

        # terminate with unknown status: 1178 falls in neither the
        # "exited" nor the "signaled" range of the fake encoding
        self.zombies[53] = 1178
        self.running = False
        self.watcher._sig_chld()

        callback.assert_called_once_with(53, 1178, -19)
        self.assertTrue(m.WIFEXITED.called)
        self.assertTrue(m.WIFSIGNALED.called)
        self.assertFalse(m.WEXITSTATUS.called)
        self.assertFalse(m.WTERMSIG.called)

        callback.reset_mock()
        m.WIFEXITED.reset_mock()
        m.WIFSIGNALED.reset_mock()

        # ensure that the child is effectively reaped
        self.add_zombie(53, 101)
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback.called)

    @waitpid_mocks
    def test_remove_child_handler(self, m):
        callback1 = mock.Mock()
        callback2 = mock.Mock()
        callback3 = mock.Mock()

        # register children
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(54, callback1, 1)
            self.watcher.add_child_handler(55, callback2, 2)
            self.watcher.add_child_handler(56, callback3, 3)

        # remove child handler 1
        self.assertTrue(self.watcher.remove_child_handler(54))

        # remove child handler 2 multiple times
        self.assertTrue(self.watcher.remove_child_handler(55))
        self.assertFalse(self.watcher.remove_child_handler(55))
        self.assertFalse(self.watcher.remove_child_handler(55))

        # all children terminate
        self.add_zombie(54, 0)
        self.add_zombie(55, 1)
        self.add_zombie(56, 2)
        self.running = False
        with self.ignore_warnings:
            self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        callback3.assert_called_once_with(56, 2, 3)

    @waitpid_mocks
    def test_sigchld_unhandled_exception(self, m):
        callback = mock.Mock()

        # register a child
        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(57, callback)

        # raise an exception
        m.waitpid.side_effect = ValueError

        with mock.patch.object(log.logger,
                               'error') as m_error:

            self.assertIsNone(self.watcher._sig_chld())
            self.assertTrue(m_error.called)

    @waitpid_mocks
    def test_sigchld_child_reaped_elsewhere(self, m):
        # register a child
        callback = mock.Mock()

        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(58, callback)

        self.assertFalse(callback.called)
        self._assert_no_status_calls(m)

        # child terminates
        self.running = False
        self.add_zombie(58, 4)

        # waitpid is called elsewhere
        os.waitpid(58, os.WNOHANG)

        m.waitpid.reset_mock()

        # sigchld
        with self.ignore_warnings:
            self._sig_chld_checked = None
            self.watcher._sig_chld()

        if isinstance(self.watcher, asyncio.FastChildWatcher):
            # here the FastChildWatcher enters a deadlock
            # (there is no way to prevent it)
            self.assertFalse(callback.called)
        else:
            callback.assert_called_once_with(58, 255)

    @waitpid_mocks
    def test_sigchld_unknown_pid_during_registration(self, m):
        # register two children
        callback1 = mock.Mock()
        callback2 = mock.Mock()

        with self.ignore_warnings, self.watcher:
            self.running = True
            # child 1 terminates
            self.add_zombie(591, 7)
            # an unknown child terminates
            self.add_zombie(593, 17)

            self.watcher._sig_chld()

            self.watcher.add_child_handler(591, callback1)
            self.watcher.add_child_handler(592, callback2)

        callback1.assert_called_once_with(591, 7)
        self.assertFalse(callback2.called)

    @waitpid_mocks
    def test_set_loop(self, m):
        # register a child
        callback = mock.Mock()

        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(60, callback)

        # attach a new loop
        old_loop = self.loop
        self.loop = self.new_test_loop()
        patch = mock.patch.object

        with patch(old_loop, "remove_signal_handler") as m_old_remove, \
             patch(self.loop, "add_signal_handler") as m_new_add:

            self.watcher.attach_loop(self.loop)

            m_old_remove.assert_called_once_with(
                signal.SIGCHLD)
            m_new_add.assert_called_once_with(
                signal.SIGCHLD, self.watcher._sig_chld)

        # child terminates
        self.running = False
        self.add_zombie(60, 9)
        self.watcher._sig_chld()

        callback.assert_called_once_with(60, 9)

    @waitpid_mocks
    def test_set_loop_race_condition(self, m):
        # register 3 children
        callback1 = mock.Mock()
        callback2 = mock.Mock()
        callback3 = mock.Mock()

        with self.watcher:
            self.running = True
            self.watcher.add_child_handler(61, callback1)
            self.watcher.add_child_handler(62, callback2)
            self.watcher.add_child_handler(622, callback3)

        # detach the loop
        old_loop = self.loop
        self.loop = None

        with mock.patch.object(
                old_loop, "remove_signal_handler") as m_remove_signal_handler:

            with self.assertWarnsRegex(
                    RuntimeWarning, 'A loop is being detached'):
                self.watcher.attach_loop(None)

            m_remove_signal_handler.assert_called_once_with(
                signal.SIGCHLD)

        # child 1 & 2 terminate
        self.add_zombie(61, 11)
        self.add_zombie(62, -5)

        # SIGCHLD was not caught
        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        self.assertFalse(callback3.called)

        # attach a new loop
        self.loop = self.new_test_loop()

        with mock.patch.object(
                self.loop, "add_signal_handler") as m_add_signal_handler:

            self.watcher.attach_loop(self.loop)

            m_add_signal_handler.assert_called_once_with(
                signal.SIGCHLD, self.watcher._sig_chld)
            callback1.assert_called_once_with(61, 11)  # race condition!
            callback2.assert_called_once_with(62, -5)  # race condition!
            self.assertFalse(callback3.called)

        callback1.reset_mock()
        callback2.reset_mock()

        # child 3 terminates
        self.running = False
        self.add_zombie(622, 19)
        self.watcher._sig_chld()

        self.assertFalse(callback1.called)
        self.assertFalse(callback2.called)
        callback3.assert_called_once_with(622, 19)

    @waitpid_mocks
    def test_close(self, m):
        # register two children
        callback1 = mock.Mock()

        with self.watcher:
            self.running = True
            # child 1 terminates
            self.add_zombie(63, 9)
            # other child terminates
            self.add_zombie(65, 18)
            self.watcher._sig_chld()

            self.watcher.add_child_handler(63, callback1)
            self.watcher.add_child_handler(64, callback1)

            self.assertEqual(len(self.watcher._callbacks), 1)
            if isinstance(self.watcher, asyncio.FastChildWatcher):
                self.assertEqual(len(self.watcher._zombies), 1)

            with mock.patch.object(
                    self.loop,
                    "remove_signal_handler") as m_remove_signal_handler:

                self.watcher.close()

                m_remove_signal_handler.assert_called_once_with(
                    signal.SIGCHLD)
                self.assertFalse(self.watcher._callbacks)
                if isinstance(self.watcher, asyncio.FastChildWatcher):
                    self.assertFalse(self.watcher._zombies)

    @waitpid_mocks
    def test_add_child_handler_with_no_loop_attached(self, m):
        callback = mock.Mock()
        with self.create_watcher() as watcher:
            with self.assertRaisesRegex(
                    RuntimeError,
                    'the child watcher does not have a loop attached'):
                watcher.add_child_handler(100, callback)
1530
1531
class SafeChildWatcherTests(ChildWatcherTestsMixin, test_utils.TestCase):
    """Run the shared child-watcher tests against SafeChildWatcher."""

    def create_watcher(self):
        watcher = asyncio.SafeChildWatcher()
        return watcher
1535
1536
class FastChildWatcherTests(ChildWatcherTestsMixin, test_utils.TestCase):
    """Run the shared child-watcher tests against FastChildWatcher."""

    def create_watcher(self):
        watcher = asyncio.FastChildWatcher()
        return watcher
1540
1541
1542class PolicyTests(unittest.TestCase):
1543
1544    def create_policy(self):
1545        return asyncio.DefaultEventLoopPolicy()
1546
1547    def test_get_child_watcher(self):
1548        policy = self.create_policy()
1549        self.assertIsNone(policy._watcher)
1550
1551        watcher = policy.get_child_watcher()
1552        self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
1553
1554        self.assertIs(policy._watcher, watcher)
1555
1556        self.assertIs(watcher, policy.get_child_watcher())
1557        self.assertIsNone(watcher._loop)
1558
1559    def test_get_child_watcher_after_set(self):
1560        policy = self.create_policy()
1561        watcher = asyncio.FastChildWatcher()
1562
1563        policy.set_child_watcher(watcher)
1564        self.assertIs(policy._watcher, watcher)
1565        self.assertIs(watcher, policy.get_child_watcher())
1566
1567    def test_get_child_watcher_with_mainloop_existing(self):
1568        policy = self.create_policy()
1569        loop = policy.get_event_loop()
1570
1571        self.assertIsNone(policy._watcher)
1572        watcher = policy.get_child_watcher()
1573
1574        self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
1575        self.assertIs(watcher._loop, loop)
1576
1577        loop.close()
1578
1579    def test_get_child_watcher_thread(self):
1580
1581        def f():
1582            policy.set_event_loop(policy.new_event_loop())
1583
1584            self.assertIsInstance(policy.get_event_loop(),
1585                                  asyncio.AbstractEventLoop)
1586            watcher = policy.get_child_watcher()
1587
1588            self.assertIsInstance(watcher, asyncio.SafeChildWatcher)
1589            self.assertIsNone(watcher._loop)
1590
1591            policy.get_event_loop().close()
1592
1593        policy = self.create_policy()
1594
1595        th = threading.Thread(target=f)
1596        th.start()
1597        th.join()
1598
1599    def test_child_watcher_replace_mainloop_existing(self):
1600        policy = self.create_policy()
1601        loop = policy.get_event_loop()
1602
1603        watcher = policy.get_child_watcher()
1604
1605        self.assertIs(watcher._loop, loop)
1606
1607        new_loop = policy.new_event_loop()
1608        policy.set_event_loop(new_loop)
1609
1610        self.assertIs(watcher._loop, new_loop)
1611
1612        policy.set_event_loop(None)
1613
1614        self.assertIs(watcher._loop, None)
1615
1616        loop.close()
1617        new_loop.close()
1618
1619
if __name__ == "__main__":
    # Allow running this test module directly, outside a test runner.
    unittest.main()
1622