• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import errno
2import os
3import random
4import selectors
5import signal
6import socket
7import sys
8from test import support
9from time import sleep
10import unittest
11import unittest.mock
12import tempfile
13from time import monotonic as time
14try:
15    import resource
16except ImportError:
17    resource = None
18
19
if hasattr(socket, 'socketpair'):
    socketpair = socket.socketpair
else:
    def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
        """Emulate socket.socketpair() over a loopback connection.

        A listening socket is bound to ``support.HOST`` on an ephemeral port
        and a client socket connects to it.  The accepted connection whose
        peer address equals the client's own address is returned together
        with the client as ``(client, server)``; connections accepted from
        any other peer are closed and ignored.  If an OSError occurs, the
        client socket is closed before the error propagates.
        """
        with socket.socket(family, type, proto) as listener:
            listener.bind((support.HOST, 0))
            listener.listen()
            client = socket.socket(family, type, proto)
            try:
                client.connect(listener.getsockname())
                client_addr = client.getsockname()
                # Keep accepting until the connection that came from our own
                # client shows up; anything else is a stray peer.
                while True:
                    server, peer_addr = listener.accept()
                    if peer_addr != client_addr:
                        server.close()
                        continue
                    return client, server
            except OSError:
                client.close()
                raise
40
41
def find_ready_matching(ready, flag):
    """Return the file objects in *ready* whose event mask overlaps *flag*.

    *ready* is an iterable of ``(key, events)`` pairs, such as the list
    returned by ``select()``.  ``key.fileobj`` is collected for every pair
    where ``events & flag`` is non-zero, preserving input order.
    """
    return [key.fileobj for key, events in ready if events & flag]
48
49
class BaseSelectorTestCase(unittest.TestCase):
    """Behavioral tests shared by every selector implementation.

    Concrete subclasses set the ``SELECTOR`` class attribute to the selector
    class under test.  Each test creates its own selector instance and
    registers ``close()`` as a cleanup.
    """

    def make_socketpair(self):
        # Return a connected (rd, wr) socket pair closed automatically on
        # test cleanup.
        rd, wr = socketpair()
        self.addCleanup(rd.close)
        self.addCleanup(wr.close)
        return rd, wr

    def test_register(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIsInstance(key, selectors.SelectorKey)
        self.assertEqual(key.fileobj, rd)
        self.assertEqual(key.fd, rd.fileno())
        self.assertEqual(key.events, selectors.EVENT_READ)
        self.assertEqual(key.data, "data")

        # register an unknown event
        self.assertRaises(ValueError, s.register, 0, 999999)

        # register an invalid FD
        self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ)

        # register twice
        self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ)

        # register the same FD, but with a different object
        self.assertRaises(KeyError, s.register, rd.fileno(),
                          selectors.EVENT_READ)

    def test_unregister(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.unregister(rd)

        # unregister an unknown file obj
        self.assertRaises(KeyError, s.unregister, 999999)

        # unregister twice
        self.assertRaises(KeyError, s.unregister, rd)

    def test_unregister_after_fd_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(r)
        s.unregister(w)

    @unittest.skipUnless(os.name == 'posix', "requires posix")
    def test_unregister_after_fd_close_and_reuse(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd2, wr2 = self.make_socketpair()
        rd.close()
        wr.close()
        # Reuse the closed fd numbers for different sockets; unregister()
        # must still succeed for the stale registrations.
        os.dup2(rd2.fileno(), r)
        os.dup2(wr2.fileno(), w)
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        s.unregister(r)
        s.unregister(w)

    def test_unregister_after_socket_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(rd)
        s.unregister(wr)

    def test_modify(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ)

        # modify events
        key2 = s.modify(rd, selectors.EVENT_WRITE)
        self.assertNotEqual(key.events, key2.events)
        self.assertEqual(key2, s.get_key(rd))

        s.unregister(rd)

        # modify data
        d1 = object()
        d2 = object()

        key = s.register(rd, selectors.EVENT_READ, d1)
        key2 = s.modify(rd, selectors.EVENT_READ, d2)
        self.assertEqual(key.events, key2.events)
        self.assertNotEqual(key.data, key2.data)
        self.assertEqual(key2, s.get_key(rd))
        self.assertEqual(key2.data, d2)

        # modify unknown file obj
        self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ)

        # modify use a shortcut: changing only the data must not go through
        # unregister()/register()
        d3 = object()
        s.register = unittest.mock.Mock()
        s.unregister = unittest.mock.Mock()

        s.modify(rd, selectors.EVENT_READ, d3)
        self.assertFalse(s.register.called)
        self.assertFalse(s.unregister.called)

    def test_modify_unregister(self):
        # Make sure the fd is unregister()ed in case of error on
        # modify(): http://bugs.python.org/issue30014
        name = self.SELECTOR.__name__
        if name not in ('EpollSelector', 'PollSelector', 'DevpollSelector'):
            # skipTest() raises SkipTest itself; give it a real reason.
            self.skipTest("%s does not wrap a poll-like _selector_cls" % name)
        patch = unittest.mock.patch('selectors.%s._selector_cls' % name)

        with patch as m:
            m.return_value.modify = unittest.mock.Mock(
                side_effect=ZeroDivisionError)
            s = self.SELECTOR()
            self.addCleanup(s.close)
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            self.assertEqual(len(s._map), 1)
            with self.assertRaises(ZeroDivisionError):
                s.modify(rd, selectors.EVENT_WRITE)
            self.assertEqual(len(s._map), 0)

    def test_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        mapping = s.get_map()
        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)

        s.close()
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)
        self.assertRaises(KeyError, mapping.__getitem__, rd)
        self.assertRaises(KeyError, mapping.__getitem__, wr)

    def test_get_key(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertEqual(key, s.get_key(rd))

        # unknown file obj
        self.assertRaises(KeyError, s.get_key, 999999)

    def test_get_map(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        keys = s.get_map()
        self.assertFalse(keys)
        self.assertEqual(len(keys), 0)
        self.assertEqual(list(keys), [])
        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIn(rd, keys)
        self.assertEqual(key, keys[rd])
        self.assertEqual(len(keys), 1)
        self.assertEqual(list(keys), [rd.fileno()])
        self.assertEqual(list(keys.values()), [key])

        # unknown file obj
        with self.assertRaises(KeyError):
            keys[999999]

        # Read-only mapping
        with self.assertRaises(TypeError):
            del keys[rd]

    def test_select(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        wr_key = s.register(wr, selectors.EVENT_WRITE)

        result = s.select()
        for key, events in result:
            self.assertIsInstance(key, selectors.SelectorKey)
            self.assertTrue(events)
            self.assertFalse(events & ~(selectors.EVENT_READ |
                                        selectors.EVENT_WRITE))

        self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result)

    def test_context_manager(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        with s as sel:
            sel.register(rd, selectors.EVENT_READ)
            sel.register(wr, selectors.EVENT_WRITE)

        # Leaving the with-block closes the selector.
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)

    def test_fileno(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        if hasattr(s, 'fileno'):
            fd = s.fileno()
            self.assertIsInstance(fd, int)
            self.assertGreaterEqual(fd, 0)

    def test_selector(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        NUM_SOCKETS = 12
        MSG = b" This is a test."
        MSG_LEN = len(MSG)
        writers = []
        r2w = {}
        w2r = {}

        for _ in range(NUM_SOCKETS):
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            s.register(wr, selectors.EVENT_WRITE)
            writers.append(wr)
            r2w[rd] = wr
            w2r[wr] = rd

        bufs = []

        while writers:
            ready = s.select()
            ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE)
            if not ready_writers:
                self.fail("no sockets ready for writing")
            wr = random.choice(ready_writers)
            wr.send(MSG)

            for _ in range(10):
                ready = s.select()
                ready_readers = find_ready_matching(ready,
                                                    selectors.EVENT_READ)
                if ready_readers:
                    break
                # there might be a delay between the write to the write end and
                # the read end is reported ready
                sleep(0.1)
            else:
                self.fail("no sockets ready for reading")
            self.assertEqual([w2r[wr]], ready_readers)
            rd = ready_readers[0]
            buf = rd.recv(MSG_LEN)
            self.assertEqual(len(buf), MSG_LEN)
            bufs.append(buf)
            s.unregister(r2w[rd])
            s.unregister(rd)
            writers.remove(r2w[rd])

        self.assertEqual(bufs, [MSG] * NUM_SOCKETS)

    @unittest.skipIf(sys.platform == 'win32',
                     'select.select() cannot be used with empty fd sets')
    def test_empty_select(self):
        # Issue #23009: Make sure EpollSelector.select() works when no FD is
        # registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        self.assertEqual(s.select(timeout=0), [])

    def test_timeout(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(wr, selectors.EVENT_WRITE)
        t = time()
        self.assertEqual(1, len(s.select(0)))
        self.assertEqual(1, len(s.select(-1)))
        self.assertLess(time() - t, 0.5)

        s.unregister(wr)
        s.register(rd, selectors.EVENT_READ)
        t = time()
        self.assertFalse(s.select(0))
        self.assertFalse(s.select(-1))
        self.assertLess(time() - t, 0.5)

        t0 = time()
        self.assertFalse(s.select(1))
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertTrue(0.8 <= dt <= 2.0, dt)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_exc(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        class InterruptSelect(Exception):
            pass

        def handler(*args):
            raise InterruptSelect

        orig_alrm_handler = signal.signal(signal.SIGALRM, handler)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal which raises an exception
            with self.assertRaises(InterruptSelect):
                s.select(30)
            # select() was interrupted before the timeout of 30 seconds
            self.assertLess(time() - t, 5.0)
        finally:
            signal.alarm(0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_noraise(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal, but the signal handler doesn't
            # raise an exception, so select() should be retried with a
            # recomputed timeout
            self.assertFalse(s.select(1.5))
            self.assertGreaterEqual(time() - t, 1.0)
        finally:
            signal.alarm(0)
439
440
class ScalableSelectorMixIn:
    """Tests for selectors that should scale beyond FD_SETSIZE descriptors.

    The host test case must provide ``SELECTOR`` and ``make_socketpair()``.
    """

    # see issue #18963 for why it's skipped on older OS X versions
    @support.requires_mac_ver(10, 5)
    @unittest.skipUnless(resource, "Test needs resource module")
    def test_above_fd_setsize(self):
        # A scalable implementation should have no problem with more than
        # FD_SETSIZE file descriptors. Since we don't know the value, we just
        # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling.
        max_fds = 2**16

        def capped(limit):
            # RLIM_INFINITY is a sentinel that may be negative on some
            # platforms, so min() alone could pick it and make the test
            # use no descriptors at all; map it to the cap explicitly.
            if limit == resource.RLIM_INFINITY:
                return max_fds
            return min(limit, max_fds)

        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
            self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE,
                            (soft, hard))
            NUM_FDS = capped(hard)
        except (OSError, ValueError):
            NUM_FDS = capped(soft)

        # guard for already allocated FDs (stdin, stdout...)
        NUM_FDS -= 32
        if NUM_FDS < 2:
            # Otherwise no socket pair would be created and the final
            # assertion would pass vacuously.
            self.skipTest("not enough file descriptors available")

        s = self.SELECTOR()
        self.addCleanup(s.close)

        for _ in range(NUM_FDS // 2):
            try:
                rd, wr = self.make_socketpair()
            except OSError:
                # too many FDs, skip - note that we should only catch EMFILE
                # here, but apparently *BSD and Solaris can fail upon connect()
                # or bind() with EADDRNOTAVAIL, so let's be safe
                self.skipTest("FD limit reached")

            try:
                s.register(rd, selectors.EVENT_READ)
                s.register(wr, selectors.EVENT_WRITE)
            except OSError as e:
                if e.errno == errno.ENOSPC:
                    # this can be raised by epoll if we go over
                    # fs.epoll.max_user_watches sysctl
                    self.skipTest("FD limit reached")
                raise

        try:
            fds = s.select()
        except OSError as e:
            if e.errno == errno.EINVAL and sys.platform == 'darwin':
                # unexplainable errors on macOS don't need to fail the test
                self.skipTest("Invalid argument error calling poll()")
            raise
        # Every write end is writable, so exactly one ready key per pair.
        self.assertEqual(NUM_FDS // 2, len(fds))
492
493
class DefaultSelectorTestCase(BaseSelectorTestCase):
    """Run the shared selector tests against selectors.DefaultSelector."""

    # DefaultSelector is whichever implementation the platform prefers.
    SELECTOR = selectors.DefaultSelector
497
498
class SelectSelectorTestCase(BaseSelectorTestCase):
    """Run the shared selector tests against selectors.SelectSelector."""

    # select()-based selector; not part of the scalability mix-in because
    # it is bounded by FD_SETSIZE.
    SELECTOR = selectors.SelectSelector
502
503
@unittest.skipUnless(hasattr(selectors, 'PollSelector'),
                     "Test needs selectors.PollSelector")
class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Run the shared and scalability tests against PollSelector."""

    # getattr() keeps the class body importable where poll() is missing;
    # the skipUnless decorator then skips the whole class.
    SELECTOR = getattr(selectors, 'PollSelector', None)
509
510
@unittest.skipUnless(hasattr(selectors, 'EpollSelector'),
                     "Test needs selectors.EpollSelector")
class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Run the shared and scalability tests against EpollSelector."""

    SELECTOR = getattr(selectors, 'EpollSelector', None)

    def test_register_file(self):
        # epoll(7) returns EPERM when given a file to watch
        s = self.SELECTOR()
        # Close the selector so its epoll descriptor is not leaked.
        self.addCleanup(s.close)
        with tempfile.NamedTemporaryFile() as f:
            with self.assertRaises(OSError):
                s.register(f, selectors.EVENT_READ)
            # the SelectorKey has been removed
            with self.assertRaises(KeyError):
                s.get_key(f)
526
527
@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'),
                     "Test needs selectors.KqueueSelector")
class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Run the shared and scalability tests against KqueueSelector."""

    SELECTOR = getattr(selectors, 'KqueueSelector', None)

    def test_register_bad_fd(self):
        # a file descriptor that's been closed should raise an OSError
        # with EBADF
        s = self.SELECTOR()
        # Close the selector so its kqueue descriptor is not leaked.
        self.addCleanup(s.close)
        bad_f = support.make_bad_fd()
        with self.assertRaises(OSError) as cm:
            s.register(bad_f, selectors.EVENT_READ)
        self.assertEqual(cm.exception.errno, errno.EBADF)
        # the SelectorKey has been removed
        with self.assertRaises(KeyError):
            s.get_key(bad_f)
545
546
@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'),
                     "Test needs selectors.DevpollSelector")
class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Run the shared and scalability tests against DevpollSelector."""

    # getattr() keeps the class body importable where /dev/poll is missing;
    # the skipUnless decorator then skips the whole class.
    SELECTOR = getattr(selectors, 'DevpollSelector', None)
552
553
554
def test_main():
    # Run every selector test case; the classes gated by skipUnless are
    # skipped when their selector is unavailable on this platform.
    support.run_unittest(
        DefaultSelectorTestCase,
        SelectSelectorTestCase,
        PollSelectorTestCase,
        EpollSelectorTestCase,
        KqueueSelectorTestCase,
        DevpollSelectorTestCase,
    )
    support.reap_children()


if __name__ == "__main__":
    test_main()
565