• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import errno
2import os
3import random
4import selectors
5import signal
6import socket
7import sys
8from test import support
9from test.support import socket_helper
10from time import sleep
11import unittest
12import unittest.mock
13import tempfile
14from time import monotonic as time
15try:
16    import resource
17except ImportError:
18    resource = None
19
20
if hasattr(socket, 'socketpair'):
    socketpair = socket.socketpair
else:
    def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
        """Emulate socket.socketpair() through a temporary loopback listener.

        Returns a ``(client, server)`` tuple of connected sockets.  The
        listener is closed before returning; on an OSError the client socket
        is closed and the error is re-raised.
        """
        with socket.socket(family, type, proto) as listener:
            listener.bind((socket_helper.HOST, 0))
            listener.listen()
            client = socket.socket(family, type, proto)
            try:
                client.connect(listener.getsockname())
                expected_peer = client.getsockname()
                # Keep accepting until the connection we just made shows up;
                # any other connection that raced in is discarded.
                while True:
                    server, peer = listener.accept()
                    if peer == expected_peer:
                        return client, server
                    server.close()
            except OSError:
                client.close()
                raise
41
42
def find_ready_matching(ready, flag):
    """Return the file objects from *ready* whose event mask includes *flag*.

    *ready* is an iterable of ``(key, events)`` pairs such as the list
    returned by ``BaseSelector.select()``; the input order is preserved.
    """
    return [key.fileobj for key, events in ready if events & flag]
49
50
class BaseSelectorTestCase(unittest.TestCase):
    """Tests shared by every selector implementation.

    Concrete subclasses set the ``SELECTOR`` class attribute to the selector
    class under test; each test builds a fresh instance with
    ``self.SELECTOR()`` and registers its ``close`` as a cleanup.
    """

    def make_socketpair(self):
        """Return a connected ``(rd, wr)`` socket pair closed at cleanup."""
        rd, wr = socketpair()
        self.addCleanup(rd.close)
        self.addCleanup(wr.close)
        return rd, wr

    def test_register(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIsInstance(key, selectors.SelectorKey)
        self.assertEqual(key.fileobj, rd)
        self.assertEqual(key.fd, rd.fileno())
        self.assertEqual(key.events, selectors.EVENT_READ)
        self.assertEqual(key.data, "data")

        # register an unknown event
        self.assertRaises(ValueError, s.register, 0, 999999)

        # register an invalid FD
        self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ)

        # register twice
        self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ)

        # register the same FD, but with a different object
        self.assertRaises(KeyError, s.register, rd.fileno(),
                          selectors.EVENT_READ)

    def test_unregister(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.unregister(rd)

        # unregister an unknown file obj
        self.assertRaises(KeyError, s.unregister, 999999)

        # unregister twice
        self.assertRaises(KeyError, s.unregister, rd)

    def test_unregister_after_fd_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        # Unregistering by raw fd must work even though the fds are closed.
        s.unregister(r)
        s.unregister(w)

    @unittest.skipUnless(os.name == 'posix', "requires posix")
    def test_unregister_after_fd_close_and_reuse(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        r, w = rd.fileno(), wr.fileno()
        s.register(r, selectors.EVENT_READ)
        s.register(w, selectors.EVENT_WRITE)
        rd2, wr2 = self.make_socketpair()
        rd.close()
        wr.close()
        # Make the old fd numbers refer to different, still-open sockets.
        os.dup2(rd2.fileno(), r)
        os.dup2(wr2.fileno(), w)
        self.addCleanup(os.close, r)
        self.addCleanup(os.close, w)
        s.unregister(r)
        s.unregister(w)

    def test_unregister_after_socket_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)
        rd, wr = self.make_socketpair()
        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)
        rd.close()
        wr.close()
        s.unregister(rd)
        s.unregister(wr)

    def test_modify(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ)

        # modify events
        key2 = s.modify(rd, selectors.EVENT_WRITE)
        self.assertNotEqual(key.events, key2.events)
        self.assertEqual(key2, s.get_key(rd))

        s.unregister(rd)

        # modify data
        d1 = object()
        d2 = object()

        key = s.register(rd, selectors.EVENT_READ, d1)
        key2 = s.modify(rd, selectors.EVENT_READ, d2)
        self.assertEqual(key.events, key2.events)
        self.assertNotEqual(key.data, key2.data)
        self.assertEqual(key2, s.get_key(rd))
        self.assertEqual(key2.data, d2)

        # modify unknown file obj
        self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ)

        # modify use a shortcut: changing only the data must not go through
        # an unregister()/register() round trip.
        d3 = object()
        s.register = unittest.mock.Mock()
        s.unregister = unittest.mock.Mock()

        s.modify(rd, selectors.EVENT_READ, d3)
        self.assertFalse(s.register.called)
        self.assertFalse(s.unregister.called)

    def test_modify_unregister(self):
        # Make sure the fd is unregister()ed in case of error on
        # modify(): http://bugs.python.org/issue30014
        # Only these selectors wrap a poll-like object stored in the
        # ``_selector_cls`` class attribute, which we can patch to fail.
        name = self.SELECTOR.__name__
        if name not in ('EpollSelector', 'PollSelector', 'DevpollSelector'):
            # skipTest() raises unittest.SkipTest itself.
            self.skipTest(f"{name} has no _selector_cls to patch")
        patch = unittest.mock.patch(f'selectors.{name}._selector_cls')

        with patch as m:
            m.return_value.modify = unittest.mock.Mock(
                side_effect=ZeroDivisionError)
            s = self.SELECTOR()
            self.addCleanup(s.close)
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            self.assertEqual(len(s._map), 1)
            with self.assertRaises(ZeroDivisionError):
                s.modify(rd, selectors.EVENT_WRITE)
            self.assertEqual(len(s._map), 0)

    def test_close(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        mapping = s.get_map()
        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        s.register(wr, selectors.EVENT_WRITE)

        s.close()
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)
        self.assertRaises(KeyError, mapping.__getitem__, rd)
        self.assertRaises(KeyError, mapping.__getitem__, wr)

    def test_get_key(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertEqual(key, s.get_key(rd))

        # unknown file obj
        self.assertRaises(KeyError, s.get_key, 999999)

    def test_get_map(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        keys = s.get_map()
        self.assertFalse(keys)
        self.assertEqual(len(keys), 0)
        self.assertEqual(list(keys), [])
        key = s.register(rd, selectors.EVENT_READ, "data")
        self.assertIn(rd, keys)
        self.assertEqual(key, keys[rd])
        self.assertEqual(len(keys), 1)
        self.assertEqual(list(keys), [rd.fileno()])
        self.assertEqual(list(keys.values()), [key])

        # unknown file obj
        with self.assertRaises(KeyError):
            keys[999999]

        # Read-only mapping
        with self.assertRaises(TypeError):
            del keys[rd]

    def test_select(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        s.register(rd, selectors.EVENT_READ)
        wr_key = s.register(wr, selectors.EVENT_WRITE)

        result = s.select()
        for key, events in result:
            self.assertIsInstance(key, selectors.SelectorKey)
            self.assertTrue(events)
            self.assertFalse(events & ~(selectors.EVENT_READ |
                                        selectors.EVENT_WRITE))

        # Nothing was written, so only the writer end is ready.
        self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result)

    def test_context_manager(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        with s as sel:
            sel.register(rd, selectors.EVENT_READ)
            sel.register(wr, selectors.EVENT_WRITE)

        # Leaving the with-block closed the selector.
        self.assertRaises(RuntimeError, s.get_key, rd)
        self.assertRaises(RuntimeError, s.get_key, wr)

    def test_fileno(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        if hasattr(s, 'fileno'):
            fd = s.fileno()
            self.assertIsInstance(fd, int)
            self.assertGreaterEqual(fd, 0)

    def test_selector(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        NUM_SOCKETS = 12
        MSG = b" This is a test."
        MSG_LEN = len(MSG)
        readers = []
        writers = []
        r2w = {}
        w2r = {}

        for _ in range(NUM_SOCKETS):
            rd, wr = self.make_socketpair()
            s.register(rd, selectors.EVENT_READ)
            s.register(wr, selectors.EVENT_WRITE)
            readers.append(rd)
            writers.append(wr)
            r2w[rd] = wr
            w2r[wr] = rd

        bufs = []

        while writers:
            ready = s.select()
            ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE)
            if not ready_writers:
                self.fail("no sockets ready for writing")
            wr = random.choice(ready_writers)
            wr.send(MSG)

            for _ in range(10):
                ready = s.select()
                ready_readers = find_ready_matching(ready,
                                                    selectors.EVENT_READ)
                if ready_readers:
                    break
                # there might be a delay between the write to the write end and
                # the read end is reported ready
                sleep(0.1)
            else:
                self.fail("no sockets ready for reading")
            self.assertEqual([w2r[wr]], ready_readers)
            rd = ready_readers[0]
            buf = rd.recv(MSG_LEN)
            self.assertEqual(len(buf), MSG_LEN)
            bufs.append(buf)
            s.unregister(r2w[rd])
            s.unregister(rd)
            writers.remove(r2w[rd])

        self.assertEqual(bufs, [MSG] * NUM_SOCKETS)

    @unittest.skipIf(sys.platform == 'win32',
                     'select.select() cannot be used with empty fd sets')
    def test_empty_select(self):
        # Issue #23009: Make sure EpollSelector.select() works when no FD is
        # registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)
        self.assertEqual(s.select(timeout=0), [])

    def test_timeout(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        # A zero or negative timeout must return immediately.
        s.register(wr, selectors.EVENT_WRITE)
        t = time()
        self.assertEqual(1, len(s.select(0)))
        self.assertEqual(1, len(s.select(-1)))
        self.assertLess(time() - t, 0.5)

        s.unregister(wr)
        s.register(rd, selectors.EVENT_READ)
        t = time()
        self.assertFalse(s.select(0))
        self.assertFalse(s.select(-1))
        self.assertLess(time() - t, 0.5)

        t0 = time()
        self.assertFalse(s.select(1))
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertTrue(0.8 <= dt <= 2.0, dt)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_exc(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        class InterruptSelect(Exception):
            pass

        def handler(*args):
            raise InterruptSelect

        orig_alrm_handler = signal.signal(signal.SIGALRM, handler)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal which raises an exception
            with self.assertRaises(InterruptSelect):
                s.select(30)
            # select() was interrupted before the timeout of 30 seconds
            self.assertLess(time() - t, 5.0)
        finally:
            signal.alarm(0)

    @unittest.skipUnless(hasattr(signal, "alarm"),
                         "signal.alarm() required for this test")
    def test_select_interrupt_noraise(self):
        s = self.SELECTOR()
        self.addCleanup(s.close)

        rd, wr = self.make_socketpair()

        orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None)
        self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)

        try:
            signal.alarm(1)

            s.register(rd, selectors.EVENT_READ)
            t = time()
            # select() is interrupted by a signal, but the signal handler
            # doesn't raise an exception, so select() should be retried with
            # a recomputed timeout
            self.assertFalse(s.select(1.5))
            self.assertGreaterEqual(time() - t, 1.0)
        finally:
            signal.alarm(0)
440
441
class ScalableSelectorMixIn:
    """Extra checks for selectors that are not bounded by FD_SETSIZE.

    Meant to be combined with BaseSelectorTestCase, which supplies the
    ``SELECTOR`` attribute and the ``make_socketpair()`` helper used here.
    """

    # see issue #18963 for why it's skipped on older OS X versions
    @support.requires_mac_ver(10, 5)
    @unittest.skipUnless(resource, "Test needs resource module")
    def test_above_fd_setsize(self):
        # A scalable implementation should have no problem with more than
        # FD_SETSIZE file descriptors. Since we don't know the value, we just
        # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling.
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
            self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE,
                            (soft, hard))
            fd_budget = min(hard, 2**16)
        except (OSError, ValueError):
            fd_budget = soft

        # Keep 32 fds in reserve for ones already allocated (stdin,
        # stdout...); every socket pair then consumes two of the rest.
        pair_count = (fd_budget - 32) // 2

        sel = self.SELECTOR()
        self.addCleanup(sel.close)

        for _ in range(pair_count):
            try:
                rd, wr = self.make_socketpair()
            except OSError:
                # too many FDs, skip - note that we should only catch EMFILE
                # here, but apparently *BSD and Solaris can fail upon connect()
                # or bind() with EADDRNOTAVAIL, so let's be safe
                self.skipTest("FD limit reached")

            try:
                sel.register(rd, selectors.EVENT_READ)
                sel.register(wr, selectors.EVENT_WRITE)
            except OSError as exc:
                if exc.errno == errno.ENOSPC:
                    # this can be raised by epoll if we go over
                    # fs.epoll.max_user_watches sysctl
                    self.skipTest("FD limit reached")
                raise

        try:
            ready = sel.select()
        except OSError as exc:
            if exc.errno == errno.EINVAL and sys.platform == 'darwin':
                # unexplainable errors on macOS don't need to fail the test
                self.skipTest("Invalid argument error calling poll()")
            raise
        # Only the write ends are ready, so exactly one event per pair.
        self.assertEqual(pair_count, len(ready))
493
494
class DefaultSelectorTestCase(BaseSelectorTestCase):
    """Run the shared selector tests against selectors.DefaultSelector."""

    SELECTOR = selectors.DefaultSelector
498
499
class SelectSelectorTestCase(BaseSelectorTestCase):
    """Run the shared selector tests against selectors.SelectSelector."""

    SELECTOR = selectors.SelectSelector
503
504
@unittest.skipUnless(hasattr(selectors, 'PollSelector'),
                     "Test needs selectors.PollSelector")
class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Shared and FD_SETSIZE-scalability tests for selectors.PollSelector."""

    # getattr() keeps class creation working on platforms without poll();
    # the skip decorator above then skips the whole class.
    SELECTOR = getattr(selectors, 'PollSelector', None)
510
511
@unittest.skipUnless(hasattr(selectors, 'EpollSelector'),
                     "Test needs selectors.EpollSelector")
class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Shared, scalability and epoll-specific tests for EpollSelector."""

    SELECTOR = getattr(selectors, 'EpollSelector', None)

    def test_register_file(self):
        # epoll(7) returns EPERM when given a file to watch
        s = self.SELECTOR()
        # Close the selector like every other test does; otherwise its
        # epoll fd stays open until garbage collection.
        self.addCleanup(s.close)
        with tempfile.NamedTemporaryFile() as f:
            # IOError is merely an alias of OSError.
            with self.assertRaises(OSError):
                s.register(f, selectors.EVENT_READ)
            # the SelectorKey has been removed
            with self.assertRaises(KeyError):
                s.get_key(f)
527
528
@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'),
                     "Test needs selectors.KqueueSelector")
class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Shared, scalability and kqueue-specific tests for KqueueSelector."""

    SELECTOR = getattr(selectors, 'KqueueSelector', None)

    def test_register_bad_fd(self):
        # a file descriptor that's been closed should raise an OSError
        # with EBADF
        s = self.SELECTOR()
        # Release the kqueue fd at the end of the test instead of leaking it.
        self.addCleanup(s.close)
        bad_f = support.make_bad_fd()
        with self.assertRaises(OSError) as cm:
            s.register(bad_f, selectors.EVENT_READ)
        self.assertEqual(cm.exception.errno, errno.EBADF)
        # the SelectorKey has been removed
        with self.assertRaises(KeyError):
            s.get_key(bad_f)

    def test_empty_select_timeout(self):
        # Issues #23009, #29255: Make sure timeout is applied when no fds
        # are registered.
        s = self.SELECTOR()
        self.addCleanup(s.close)

        t0 = time()
        self.assertEqual(s.select(1), [])
        t1 = time()
        dt = t1 - t0
        # Tolerate 2.0 seconds for very slow buildbots
        self.assertTrue(0.8 <= dt <= 2.0, dt)
559
560
@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'),
                     "Test needs selectors.DevpollSelector")
class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
    """Shared and FD_SETSIZE-scalability tests for DevpollSelector."""

    # getattr() keeps class creation working where /dev/poll is absent;
    # the skip decorator above then skips the whole class.
    SELECTOR = getattr(selectors, 'DevpollSelector', None)
566
567
568
def test_main():
    """Run every selector test case, then call support.reap_children()."""
    support.run_unittest(
        DefaultSelectorTestCase,
        SelectSelectorTestCase,
        PollSelectorTestCase,
        EpollSelectorTestCase,
        KqueueSelectorTestCase,
        DevpollSelectorTestCase,
    )
    support.reap_children()
575
576
if __name__ == "__main__":
    # Running this module as a script executes the whole selector suite.
    test_main()
579