• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import errno
2import os
3import random
4import selectors
5import signal
6import socket
7import sys
8from test import support
9from test.support import is_apple, os_helper, socket_helper
10from time import sleep
11import unittest
12import unittest.mock
13import tempfile
14from time import monotonic as time
15try:
16    import resource
17except ImportError:
18    resource = None
19
20
21if support.is_emscripten or support.is_wasi:
22    raise unittest.SkipTest("Cannot create socketpair on Emscripten/WASI.")
23
24
# Use the native socket.socketpair() when the interpreter provides it;
# otherwise emulate it with a listening socket and a connecting client.
if not hasattr(socket, 'socketpair'):
    def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
        """Return a connected ``(client, server)`` pair of sockets.

        A temporary listener is bound to ``socket_helper.HOST`` on an
        ephemeral port.  Accepted connections that do not originate from our
        own client are closed and ignored.  On ``OSError`` the client socket
        is closed and the error re-raised.
        """
        with socket.socket(family, type, proto) as listener:
            listener.bind((socket_helper.HOST, 0))
            listener.listen()
            client = socket.socket(family, type, proto)
            try:
                client.connect(listener.getsockname())
                expected_peer = client.getsockname()
                while True:
                    conn, peer = listener.accept()
                    # Drop connections from anyone other than our client.
                    if peer != expected_peer:
                        conn.close()
                        continue
                    return client, conn
            except OSError:
                client.close()
                raise
else:
    socketpair = socket.socketpair
45
46
def find_ready_matching(ready, flag):
    """Return the file objects in *ready* whose event mask intersects *flag*.

    *ready* is an iterable of ``(key, events)`` pairs, such as the result of
    a selector's ``select()``; the input order is preserved.
    """
    return [key.fileobj for key, events in ready if events & flag]
53
54
class BaseSelectorTestCase:
    """Selector-independent tests shared by the concrete selector test cases.

    Concrete subclasses also inherit from ``unittest.TestCase`` and set the
    ``SELECTOR`` class attribute to the selector class being exercised.
    """

    def make_socketpair(self):
        """Return a connected ``(rd, wr)`` socket pair closed at cleanup."""
        rd, wr = socketpair()
        self.addCleanup(rd.close)
        self.addCleanup(wr.close)
        return rd, wr

    def test_register(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIsInstance(key, selectors.SelectorKey)
        self.assertEqual(key.fileobj, rd)
        self.assertEqual(key.fd, rd.fileno())
        self.assertEqual(key.events, selectors.EVENT_READ)
        self.assertEqual(key.data, "data")

        # register an unknown event
        self.assertRaises(ValueError, s.register, 0, 999999)

        # register an invalid FD
        self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ)

        # register twice
        self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ)

        # register the same FD, but with a different object
        self.assertRaises(KeyError, s.register, rd.fileno(),
                          selectors.EVENT_READ)

    def test_unregister(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.unregister(rd)

        # unregister an unknown file obj
        self.assertRaises(KeyError, s.unregister, 999999)

        # unregister twice
        self.assertRaises(KeyError, s.unregister, rd)

    def test_unregister_after_fd_close(self):
        # Unregistering by raw fd must work after the sockets are closed.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(r)
        s.unregister(w)

    @unittest.skipUnless(os.name == 'posix', "requires posix")
    def test_unregister_after_fd_close_and_reuse(self):
        # Same as above, but the fd numbers are reused (via dup2) by other
        # sockets before unregister() is called.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd2, wr2 = self.make_socketpair()
        rd.close()
        wr.close()
        os.dup2(rd2.fileno(), r)
        os.dup2(wr2.fileno(), w)
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        s.unregister(r)
        s.unregister(w)

    def test_unregister_after_socket_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(rd)
        s.unregister(wr)

    def test_modify(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ)

        # modify events
        key2 = s.modify(rd, selectors.EVENT_WRITE)
        self.assertNotEqual(key.events, key2.events)
        self.assertEqual(key2, s.get_key(rd))

        s.unregister(rd)

        # modify data
        d1 = object()
        d2 = object()

        key = s.register(rd, selectors.EVENT_READ, d1)
        key2 = s.modify(rd, selectors.EVENT_READ, d2)
        self.assertEqual(key.events, key2.events)
        self.assertNotEqual(key.data, key2.data)
        self.assertEqual(key2, s.get_key(rd))
        self.assertEqual(key2.data, d2)

        # modify unknown file obj
        self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ)

        # modify use a shortcut: changing only the data must not go through
        # unregister()/register()
        d3 = object()
        s.register = unittest.mock.Mock()
        s.unregister = unittest.mock.Mock()

        s.modify(rd, selectors.EVENT_READ, d3)
        self.assertFalse(s.register.called)
        self.assertFalse(s.unregister.called)

    def test_modify_unregister(self):
        # Make sure the fd is unregister()ed in case of error on
        # modify(): http://bugs.python.org/issue30014
        name = self.SELECTOR.__name__
        if name not in ('EpollSelector', 'PollSelector', 'DevpollSelector'):
            # skipTest() raises SkipTest itself; no explicit raise needed.
            self.skipTest(f"test only applies to poll-like selectors, "
                          f"not {name}")
        patch = unittest.mock.patch(f'selectors.{name}._selector_cls')

        with patch as m:
            m.return_value.modify = unittest.mock.Mock(
                side_effect=ZeroDivisionError)
            s = self.SELECTOR()
            self.addCleanup(s.close)
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            self.assertEqual(len(s._map), 1)
            with self.assertRaises(ZeroDivisionError):
                s.modify(rd, selectors.EVENT_WRITE)
            self.assertEqual(len(s._map), 0)

    def test_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        mapping = s.get_map()
        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)

        s.close()
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)
        self.assertRaises(KeyError, mapping.__getitem__, rd)
        self.assertRaises(KeyError, mapping.__getitem__, wr)
        self.assertEqual(mapping.get(rd), None)
        self.assertEqual(mapping.get(wr), None)

    def test_get_key(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertEqual(key, s.get_key(rd))

        # unknown file obj
        self.assertRaises(KeyError, s.get_key, 999999)

    def test_get_map(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()
        sentinel = object()

        keys = s.get_map()
        self.assertFalse(keys)
        self.assertEqual(len(keys), 0)
        self.assertEqual(list(keys), [])
        self.assertEqual(keys.get(rd), None)
        self.assertEqual(keys.get(rd, sentinel), sentinel)
        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIn(rd, keys)
        self.assertEqual(key, keys.get(rd))
        self.assertEqual(key, keys[rd])
        self.assertEqual(len(keys), 1)
        self.assertEqual(list(keys), [rd.fileno()])
        self.assertEqual(list(keys.values()), [key])

        # unknown file obj
        with self.assertRaises(KeyError):
            keys[999999]

        # Read-only mapping
        with self.assertRaises(TypeError):
            del keys[rd]

    def test_select(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        wr_key = s.register(wr, selectors.EVENT_WRITE)

        result = s.select()
        for key, events in result:
            self.assertIsInstance(key, selectors.SelectorKey)
            self.assertTrue(events)
            self.assertFalse(events & ~(selectors.EVENT_READ |
                                        selectors.EVENT_WRITE))

        self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result)

    def test_select_read_write(self):
        # gh-110038: when a file descriptor is registered for both read and
        # write, the two events must be seen on a single call to select().
        s = self.SELECTOR()
        self.addCleanup(s.close)

        sock1, sock2 = self.make_socketpair()
        sock2.send(b"foo")
        my_key = s.register(sock1, selectors.EVENT_READ | selectors.EVENT_WRITE)

        seen_read, seen_write = False, False
        result = s.select()
        # We get the read and write either in the same result entry or in two
        # distinct entries with the same key.
        self.assertLessEqual(len(result), 2)
        for key, events in result:
            self.assertIsInstance(key, selectors.SelectorKey)
            self.assertEqual(key, my_key)
            self.assertFalse(events & ~(selectors.EVENT_READ |
                                        selectors.EVENT_WRITE))
            if events & selectors.EVENT_READ:
                self.assertFalse(seen_read)
                seen_read = True
            if events & selectors.EVENT_WRITE:
                self.assertFalse(seen_write)
                seen_write = True
        self.assertTrue(seen_read)
        self.assertTrue(seen_write)

    def test_context_manager(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        with s as sel:
            sel.register(rd, selectors.EVENT_READ)
            sel.register(wr, selectors.EVENT_WRITE)

        # leaving the with-block closed the selector
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)

    def test_fileno(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        # fileno() is optional; only check it when the selector has one.
        if hasattr(s, 'fileno'):
            fd = s.fileno()
            self.assertIsInstance(fd, int)
            self.assertGreaterEqual(fd, 0)

    def test_selector(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        NUM_SOCKETS = 12
        MSG = b" This is a test."
        MSG_LEN = len(MSG)
        readers = []
        writers = []
        r2w = {}
        w2r = {}

        for _ in range(NUM_SOCKETS):
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            s.register(wr, selectors.EVENT_WRITE)
            readers.append(rd)
            writers.append(wr)
            r2w[rd] = wr
            w2r[wr] = rd

        bufs = []

        while writers:
            ready = s.select()
            ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE)
            if not ready_writers:
                self.fail("no sockets ready for writing")
            wr = random.choice(ready_writers)
            wr.send(MSG)

            for _ in range(10):
                ready = s.select()
                ready_readers = find_ready_matching(ready,
                                                    selectors.EVENT_READ)
                if ready_readers:
                    break
                # there might be a delay between the write to the write end and
                # the read end is reported ready
                sleep(0.1)
            else:
                self.fail("no sockets ready for reading")
            self.assertEqual([w2r[wr]], ready_readers)
            rd = ready_readers[0]
            buf = rd.recv(MSG_LEN)
            self.assertEqual(len(buf), MSG_LEN)
            bufs.append(buf)
            s.unregister(r2w[rd])
            s.unregister(rd)
            writers.remove(r2w[rd])

        self.assertEqual(bufs, [MSG] * NUM_SOCKETS)

    @unittest.skipIf(sys.platform == 'win32',
                     'select.select() cannot be used with empty fd sets')
    def test_empty_select(self):
        # Issue #23009: Make sure EpollSelector.select() works when no FD is
        # registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        self.assertEqual(s.select(timeout=0), [])

    def test_timeout(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(wr, selectors.EVENT_WRITE)
        t = time()
        self.assertEqual(1, len(s.select(0)))
        self.assertEqual(1, len(s.select(-1)))
        self.assertLess(time() - t, 0.5)

        s.unregister(wr)
        s.register(rd, selectors.EVENT_READ)
        t = time()
        self.assertFalse(s.select(0))
        self.assertFalse(s.select(-1))
        self.assertLess(time() - t, 0.5)

        t0 = time()
        self.assertFalse(s.select(1))
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertTrue(0.8 <= dt <= 2.0, dt)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_exc(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        class InterruptSelect(Exception):
            pass

        def handler(*args):
            raise InterruptSelect

        orig_alrm_handler = signal.signal(signal.SIGALRM, handler)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal which raises an exception
            with self.assertRaises(InterruptSelect):
                s.select(30)
            # select() was interrupted before the timeout of 30 seconds
            self.assertLess(time() - t, 5.0)
        finally:
            signal.alarm(0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_noraise(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal, but the signal handler doesn't
            # raise an exception, so select() should be retried with a
            # recomputed timeout
            self.assertFalse(s.select(1.5))
            self.assertGreaterEqual(time() - t, 1.0)
        finally:
            signal.alarm(0)
479
480
class ScalableSelectorMixIn:
    """Extra tests for selectors expected to handle many file descriptors."""

    # see issue #18963 for why it's skipped on older OS X versions
    @support.requires_mac_ver(10, 5)
    @unittest.skipUnless(resource, "Test needs resource module")
    @support.requires_resource('cpu')
    def test_above_fd_setsize(self):
        # A scalable implementation should have no problem with more than
        # FD_SETSIZE file descriptors. Since we don't know the value, we just
        # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling.
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
            self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE,
                            (soft, hard))
            # resource.RLIM_INFINITY is not necessarily a large positive
            # number (on Linux it is -1), so min() alone could produce a
            # negative count; treat "unlimited" as the 2**16 cap explicitly.
            if hard == resource.RLIM_INFINITY:
                NUM_FDS = 2**16
            else:
                NUM_FDS = min(hard, 2**16)
        except (OSError, ValueError):
            NUM_FDS = soft

        # guard for already allocated FDs (stdin, stdout...)
        NUM_FDS -= 32

        s = self.SELECTOR()
        self.addCleanup(s.close)

        for _ in range(NUM_FDS // 2):
            try:
                rd, wr = self.make_socketpair()
            except OSError:
                # too many FDs, skip - note that we should only catch EMFILE
                # here, but apparently *BSD and Solaris can fail upon connect()
                # or bind() with EADDRNOTAVAIL, so let's be safe
                self.skipTest("FD limit reached")

            try:
                s.register(rd, selectors.EVENT_READ)
                s.register(wr, selectors.EVENT_WRITE)
            except OSError as e:
                if e.errno == errno.ENOSPC:
                    # this can be raised by epoll if we go over
                    # fs.epoll.max_user_watches sysctl
                    self.skipTest("FD limit reached")
                raise

        try:
            fds = s.select()
        except OSError as e:
            if e.errno == errno.EINVAL and is_apple:
                # unexplainable errors on macOS don't need to fail the test
                self.skipTest("Invalid argument error calling poll()")
            raise
        # Only the write ends are ready, one per registered socket pair.
        self.assertEqual(NUM_FDS // 2, len(fds))
533
534
class DefaultSelectorTestCase(BaseSelectorTestCase, unittest.TestCase):
    """Run the shared selector tests against ``selectors.DefaultSelector``."""

    SELECTOR = selectors.DefaultSelector
538
539
class SelectSelectorTestCase(BaseSelectorTestCase, unittest.TestCase):
    """Run the shared selector tests against ``selectors.SelectSelector``."""

    # Not scalable, so ScalableSelectorMixIn is deliberately not mixed in.
    SELECTOR = selectors.SelectSelector
543
544
@unittest.skipUnless(hasattr(selectors, 'PollSelector'),
                     "Test needs selectors.PollSelector")
class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                           unittest.TestCase):
    """Run the shared and scalability tests against ``PollSelector``."""

    # getattr() keeps class creation working where PollSelector is missing;
    # the skip decorator above then prevents the tests from running.
    SELECTOR = getattr(selectors, 'PollSelector', None)
551
552
@unittest.skipUnless(hasattr(selectors, 'EpollSelector'),
                     "Test needs selectors.EpollSelector")
class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                            unittest.TestCase):
    """Shared, scalability and epoll-specific tests for ``EpollSelector``."""

    # getattr() keeps class creation working where EpollSelector is missing.
    SELECTOR = getattr(selectors, 'EpollSelector', None)

    def test_register_file(self):
        # epoll(7) returns EPERM when given a file to watch
        s = self.SELECTOR()
        # Close the selector even if an assertion below fails (it was leaked).
        self.addCleanup(s.close)
        with tempfile.NamedTemporaryFile() as f:
            # IOError is an alias of OSError; use the canonical name.
            with self.assertRaises(OSError):
                s.register(f, selectors.EVENT_READ)
            # the SelectorKey has been removed
            with self.assertRaises(KeyError):
                s.get_key(f)
569
570
@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'),
                     "Test needs selectors.KqueueSelector")
class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                             unittest.TestCase):
    """Shared, scalability and kqueue-specific tests for ``KqueueSelector``."""

    # getattr() keeps class creation working where KqueueSelector is missing.
    SELECTOR = getattr(selectors, 'KqueueSelector', None)

    def test_register_bad_fd(self):
        # a file descriptor that's been closed should raise an OSError
        # with EBADF
        s = self.SELECTOR()
        # Close the selector even if an assertion below fails (it was leaked).
        self.addCleanup(s.close)
        bad_f = os_helper.make_bad_fd()
        with self.assertRaises(OSError) as cm:
            s.register(bad_f, selectors.EVENT_READ)
        self.assertEqual(cm.exception.errno, errno.EBADF)
        # the SelectorKey has been removed
        with self.assertRaises(KeyError):
            s.get_key(bad_f)

    def test_empty_select_timeout(self):
        # Issues #23009, #29255: Make sure timeout is applied when no fds
        # are registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)

        t0 = time()
        self.assertEqual(s.select(1), [])
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertTrue(0.8 <= dt <= 2.0, dt)
602
603
@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'),
                     "Test needs selectors.DevpollSelector")
class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn,
                              unittest.TestCase):
    """Run the shared and scalability tests against ``DevpollSelector``."""

    # getattr() keeps class creation working where DevpollSelector is missing;
    # the skip decorator above then prevents the tests from running.
    SELECTOR = getattr(selectors, 'DevpollSelector', None)
610
611
def tearDownModule():
    """unittest module-level hook: call ``support.reap_children()``."""
    support.reap_children()
614
615
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
618