• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Various tests for synchronization primitives.
3"""
4
5import gc
6import sys
7import time
8from _thread import start_new_thread, TIMEOUT_MAX
9import threading
10import unittest
11import weakref
12
13from test import support
14from test.support import threading_helper
15
16
# Skip decorator for tests that call _at_fork_reinit(); per the skip message,
# that method is missing on platforms without fork support.
requires_fork = unittest.skipUnless(
    support.has_fork_support,
    "platform doesn't support fork "
    "(no _at_fork_reinit method)",
)
20
21
def wait_threads_blocked(nthread):
    """Give `nthread` threads time to reach a blocking call.

    This cannot confirm that the threads really are blocked (for example
    waiting for a lock); it only sleeps an arbitrary 10 ms per thread.
    """
    delay = nthread * 0.010
    time.sleep(delay)
26
27
class Bunch:
    """
    Run one callable concurrently in a fixed number of raw threads.

    Used as a context manager: entering starts the threads and waits until
    each has begun running; exiting waits until each has finished calling
    the function, then raises an ExceptionGroup if any call raised.
    """

    def __init__(self, func, nthread, wait_before_exit=False):
        """
        Prepare `nthread` threads that will each call `func()`.

        If `wait_before_exit` is true, every thread keeps polling (for at
        most support.SHORT_TIMEOUT) after `func()` returns, until
        do_finish() is called.
        """
        self.func = func
        self.nthread = nthread
        # Thread identifiers, recorded as threads start and as they finish.
        self.started = []
        self.finished = []
        # Exceptions raised by `func` in the worker threads.
        self.exceptions = []
        self._can_exit = not wait_before_exit
        # wait_threads_exit() context: entered in __enter__, exited in __exit__.
        self._wait_thread = None

    @staticmethod
    def _wait_for(predicate):
        # Poll `predicate` via support.sleeping_retry() until it is true.
        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
            if predicate():
                break

    def task(self):
        ident = threading.get_ident()
        self.started.append(ident)
        try:
            self.func()
        except BaseException as err:
            # Collected here; __exit__() raises them in the owning thread.
            self.exceptions.append(err)
        finally:
            self.finished.append(ident)
            # Linger until the owner allows this thread to exit.
            self._wait_for(lambda: self._can_exit)

    def __enter__(self):
        waiter = threading_helper.wait_threads_exit(support.SHORT_TIMEOUT)
        self._wait_thread = waiter
        waiter.__enter__()

        try:
            for _ in range(self.nthread):
                start_new_thread(self.task, ())
        except BaseException:
            # Let any already-started threads exit, then propagate.
            self._can_exit = True
            raise

        self._wait_for(lambda: len(self.started) >= self.nthread)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._wait_for(lambda: len(self.finished) >= self.nthread)

        # Let wait_threads_exit() wait for the threads themselves to exit,
        # not merely for func() to return.
        self._wait_thread.__exit__(None, None, None)

        # Detach the list before raising: the exceptions' tracebacks include
        # task() frames, which reference self (a reference cycle).
        errors, self.exceptions = self.exceptions, None
        if errors:
            raise ExceptionGroup(f"{self.func} threads raised exceptions",
                                 errors)

    def do_finish(self):
        """Allow threads created with wait_before_exit=True to terminate."""
        self._can_exit = True
93
94
class BaseTestCase(unittest.TestCase):
    """Shared fixture and helpers for the synchronization-primitive tests."""

    def setUp(self):
        # Snapshot from threading_helper.threading_setup(); handed back to
        # threading_cleanup() in tearDown().
        self._threads = threading_helper.threading_setup()

    def tearDown(self):
        threading_helper.threading_cleanup(*self._threads)
        support.reap_children()

    def assertTimeout(self, actual, expected):
        """Assert that a measured duration `actual` fits a timeout `expected`.

        The wait itself and time.monotonic() can both be imprecise
        (especially on Windows), so only 60% of `expected` is required.
        The 10x upper bound just rules out something pathological.
        """
        lower_bound = expected * 0.6
        upper_bound = expected * 10.0
        self.assertGreaterEqual(actual, lower_bound)
        self.assertLess(actual, upper_bound)
110
111
class BaseLockTests(BaseTestCase):
    """
    Tests for both recursive and non-recursive locks.

    Concrete subclasses provide `locktype`, the lock factory under test.
    """

    def wait_phase(self, phase, expected):
        """Poll until `phase` holds `expected` items, then assert the count."""
        for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
            if len(phase) >= expected:
                break
        self.assertEqual(len(phase), expected)

    def test_constructor(self):
        lock = self.locktype()
        del lock

    def test_repr(self):
        lock = self.locktype()
        self.assertRegex(repr(lock), "<unlocked .* object (.*)?at .*>")
        del lock

    def test_locked_repr(self):
        lock = self.locktype()
        lock.acquire()
        self.assertRegex(repr(lock), "<locked .* object (.*)?at .*>")
        del lock

    def test_acquire_destroy(self):
        # Dropping the last reference to a held lock must not fail.
        lock = self.locktype()
        lock.acquire()
        del lock

    def test_acquire_release(self):
        lock = self.locktype()
        lock.acquire()
        lock.release()
        del lock

    def test_try_acquire(self):
        lock = self.locktype()
        self.assertTrue(lock.acquire(False))
        lock.release()

    def test_try_acquire_contended(self):
        # A non-blocking acquire from another thread fails while we hold it.
        lock = self.locktype()
        lock.acquire()
        result = []
        def f():
            result.append(lock.acquire(False))
        with Bunch(f, 1):
            pass
        self.assertFalse(result[0])
        lock.release()

    def test_acquire_contended(self):
        lock = self.locktype()
        lock.acquire()
        def f():
            lock.acquire()
            lock.release()

        N = 5
        with Bunch(f, N) as bunch:
            # Threads block on lock.acquire()
            wait_threads_blocked(N)
            self.assertEqual(len(bunch.finished), 0)

            # Threads unblocked
            lock.release()

        self.assertEqual(len(bunch.finished), N)

    def test_with(self):
        lock = self.locktype()
        def f():
            # Only completes if the lock is free.
            lock.acquire()
            lock.release()

        # Acquire the lock, do nothing, with releases the lock
        with lock:
            pass

        # Check that the lock is unacquired
        with Bunch(f, 1):
            pass

        # Acquire the lock, raise an exception, with releases the lock
        with self.assertRaises(TypeError):
            with lock:
                raise TypeError

        # Check that the lock is unacquired even after an exception
        # was raised in the previous "with lock:" block
        with Bunch(f, 1):
            pass

    def test_thread_leak(self):
        # The lock shouldn't leak a Thread instance when used from a foreign
        # (non-threading) thread.
        lock = self.locktype()
        def f():
            lock.acquire()
            lock.release()

        # We run many threads in the hope that existing threads ids won't
        # be recycled.
        with Bunch(f, 15):
            pass

    def test_timeout(self):
        lock = self.locktype()
        # Can't set timeout if not blocking
        self.assertRaises(ValueError, lock.acquire, False, 1)
        # Invalid timeout values
        self.assertRaises(ValueError, lock.acquire, timeout=-100)
        self.assertRaises(OverflowError, lock.acquire, timeout=1e100)
        self.assertRaises(OverflowError, lock.acquire, timeout=TIMEOUT_MAX + 1)
        # TIMEOUT_MAX is ok
        lock.acquire(timeout=TIMEOUT_MAX)
        lock.release()
        t1 = time.monotonic()
        self.assertTrue(lock.acquire(timeout=5))
        t2 = time.monotonic()
        # Just a sanity test that it didn't actually wait for the timeout.
        self.assertLess(t2 - t1, 5)
        results = []
        def f():
            t1 = time.monotonic()
            results.append(lock.acquire(timeout=0.5))
            t2 = time.monotonic()
            results.append(t2 - t1)
        with Bunch(f, 1):
            pass
        self.assertFalse(results[0])
        self.assertTimeout(results[1], 0.5)

    def test_weakref_exists(self):
        lock = self.locktype()
        ref = weakref.ref(lock)
        self.assertIsNotNone(ref())

    def test_weakref_deleted(self):
        lock = self.locktype()
        ref = weakref.ref(lock)
        del lock
        gc.collect()  # For PyPy or other GCs.
        self.assertIsNone(ref())
263
264
class LockTests(BaseLockTests):
    """
    Tests for non-recursive, weak locks: such a lock may be released by a
    thread other than the one that acquired it.
    """
    def test_reacquire(self):
        # A second acquire() by the holder blocks until someone releases.
        mylock = self.locktype()
        steps = []

        def worker():
            mylock.acquire()
            steps.append(None)
            mylock.acquire()
            steps.append(None)

        with threading_helper.wait_threads_exit():
            start_new_thread(worker, ())
            # The worker holds the lock and goes on to block in its
            # second acquire().
            self.wait_phase(steps, 1)

            # A release from this thread lets the worker through.
            mylock.release()
            self.wait_phase(steps, 2)

    def test_different_thread(self):
        # Another thread may release a lock acquired by this one.
        mylock = self.locktype()
        mylock.acquire()

        def releaser():
            mylock.release()

        with Bunch(releaser, 1):
            pass
        # The lock is free again.
        mylock.acquire()
        mylock.release()

    def test_state_after_timeout(self):
        # Issue #11618: an acquire() that times out after a non-zero
        # timeout must leave the lock in a consistent state.
        mylock = self.locktype()
        mylock.acquire()
        self.assertFalse(mylock.acquire(timeout=0.01))
        mylock.release()
        self.assertFalse(mylock.locked())
        self.assertTrue(mylock.acquire(blocking=False))

    @requires_fork
    def test_at_fork_reinit(self):
        def exercise(lk):
            # The lock must keep working normally after _at_fork_reinit().
            lk.acquire()
            lk.release()

        # Reinitialising an unlocked lock.
        fresh = self.locktype()
        fresh._at_fork_reinit()
        exercise(fresh)

        # Reinitialising a held lock resets it to the unlocked state.
        held = self.locktype()
        held.acquire()
        held._at_fork_reinit()
        exercise(held)
329
330
class RLockTests(BaseLockTests):
    """
    Tests for recursive locks.
    """
    @staticmethod
    def _acquire_release_nested(rlock):
        # Acquire twice, release once, acquire, then release twice: the
        # lock ends up fully released.
        rlock.acquire()
        rlock.acquire()
        rlock.release()
        rlock.acquire()
        rlock.release()
        rlock.release()

    def test_reacquire(self):
        self._acquire_release_nested(self.locktype())

    def test_release_unacquired(self):
        # release() on a lock that is not held raises RuntimeError.
        rlock = self.locktype()
        self.assertRaises(RuntimeError, rlock.release)
        self._acquire_release_nested(rlock)
        self.assertRaises(RuntimeError, rlock.release)

    def test_release_save_unacquired(self):
        # Likewise for _release_save().
        rlock = self.locktype()
        self.assertRaises(RuntimeError, rlock._release_save)
        self._acquire_release_nested(rlock)
        self.assertRaises(RuntimeError, rlock._release_save)

    def test_recursion_count(self):
        rlock = self.locktype()
        self.assertEqual(0, rlock._recursion_count())
        rlock.acquire()
        self.assertEqual(1, rlock._recursion_count())
        rlock.acquire()
        rlock.acquire()
        self.assertEqual(3, rlock._recursion_count())
        rlock.release()
        self.assertEqual(2, rlock._recursion_count())
        rlock.release()
        rlock.release()
        self.assertEqual(0, rlock._recursion_count())

        steps = []

        def holder():
            rlock.acquire()
            steps.append(None)
            # Keep the lock until the main thread advances the phase.
            self.wait_phase(steps, 2)
            rlock.release()
            steps.append(None)

        with threading_helper.wait_threads_exit():
            start_new_thread(holder, ())
            # Another thread owns the lock: our recursion count stays 0.
            self.wait_phase(steps, 1)
            self.assertEqual(0, rlock._recursion_count())

            # Let the holder release the lock and finish.
            steps.append(None)
            self.wait_phase(steps, 3)
            self.assertEqual(0, rlock._recursion_count())

    def test_different_thread(self):
        # A thread that does not own the lock cannot release it.
        rlock = self.locktype()

        def owner():
            rlock.acquire()

        with Bunch(owner, 1, True) as bunch:
            try:
                self.assertRaises(RuntimeError, rlock.release)
            finally:
                bunch.do_finish()

    def test__is_owned(self):
        rlock = self.locktype()
        self.assertFalse(rlock._is_owned())
        rlock.acquire()
        self.assertTrue(rlock._is_owned())
        rlock.acquire()
        self.assertTrue(rlock._is_owned())
        seen = []

        def probe():
            seen.append(rlock._is_owned())

        with Bunch(probe, 1):
            pass
        # Ownership is per thread: the other thread did not own it.
        self.assertFalse(seen[0])
        rlock.release()
        self.assertTrue(rlock._is_owned())
        rlock.release()
        self.assertFalse(rlock._is_owned())
432
433
class EventTests(BaseTestCase):
    """
    Tests for Event objects (subclasses provide `eventtype`).
    """

    def test_is_set(self):
        event = self.eventtype()
        self.assertFalse(event.is_set())
        # set() and clear() may each be repeated harmlessly.
        for _ in range(2):
            event.set()
            self.assertTrue(event.is_set())
        for _ in range(2):
            event.clear()
            self.assertFalse(event.is_set())

    def _check_notify(self, evt):
        # A single set() wakes every waiting thread.
        nthreads = 5
        first = []
        second = []

        def waiter():
            first.append(evt.wait())
            second.append(evt.wait())

        with Bunch(waiter, nthreads):
            # Give the threads time to block in their first wait().
            wait_threads_blocked(nthreads)
            self.assertEqual(len(first), 0)

            evt.set()

        self.assertEqual(first, [True] * nthreads)
        self.assertEqual(second, [True] * nthreads)

    def test_notify(self):
        evt = self.eventtype()
        self._check_notify(evt)
        # Repeat after an explicit clear() to check the event is reusable.
        evt.set()
        evt.clear()
        self._check_notify(evt)

    def test_timeout(self):
        evt = self.eventtype()
        nthreads = 5
        immediate = []
        timed = []

        def waiter():
            immediate.append(evt.wait(0.0))
            start = time.monotonic()
            flag = evt.wait(0.5)
            elapsed = time.monotonic() - start
            timed.append((flag, elapsed))

        # Event unset: both waits time out.
        with Bunch(waiter, nthreads):
            pass

        self.assertEqual(immediate, [False] * nthreads)
        for flag, elapsed in timed:
            self.assertFalse(flag)
            self.assertTimeout(elapsed, 0.5)

        # Event set: both waits succeed.
        immediate = []
        timed = []
        evt.set()
        with Bunch(waiter, nthreads):
            pass

        self.assertEqual(immediate, [True] * nthreads)
        for flag, _elapsed in timed:
            self.assertTrue(flag)

    def test_set_and_clear(self):
        # gh-57711: wait() must return True even when the event is cleared
        # again before the waiting thread gets to run.
        event = self.eventtype()
        outcomes = []

        def waiter():
            outcomes.append(event.wait(support.LONG_TIMEOUT))

        nthreads = 5
        with Bunch(waiter, nthreads):
            # Let the threads block in event.wait(), then wake them.
            wait_threads_blocked(nthreads)
            event.set()
            event.clear()

        self.assertEqual(outcomes, [True] * nthreads)

    @requires_fork
    def test_at_fork_reinit(self):
        # The event's condition must still wrap a non-recursive lock after
        # _at_fork_reinit(): acquiring it again while held must fail.
        evt = self.eventtype()

        def check_nonrecursive():
            with evt._cond:
                self.assertFalse(evt._cond.acquire(False))

        check_nonrecursive()
        evt._at_fork_reinit()
        check_nonrecursive()

    def test_repr(self):
        evt = self.eventtype()
        self.assertRegex(repr(evt), r"<\w+\.Event at .*: unset>")
        evt.set()
        self.assertRegex(repr(evt), r"<\w+\.Event at .*: set>")
544
545
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables (subclasses provide `condtype`).
    """

    def test_acquire(self):
        # By default the condition wraps an RLock, so it can be acquired
        # several times by the same thread.
        cond = self.condtype()
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()

        # With an explicit Lock, holding the condition means holding it.
        plain = threading.Lock()
        cond = self.condtype(plain)
        cond.acquire()
        self.assertFalse(plain.acquire(False))
        cond.release()
        self.assertTrue(plain.acquire(False))
        self.assertFalse(cond.acquire(False))
        plain.release()
        with cond:
            self.assertFalse(plain.acquire(False))

    def test_unacquired_wait(self):
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.wait)

    def test_unacquired_notify(self):
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.notify)

    def _check_notify(self, cond):
        # This test is timing-sensitive: if the workers run late, the main
        # thread may think they are further along than they really are, so
        # it calls wait_threads_blocked() to avoid racing ahead of them.
        # It also assumes that condition variables have no spurious wakeups.
        # That holds for CPython's current implementation but is not
        # guaranteed for condition variables in general, and may stop
        # holding if the implementation ever changes.
        ready = []
        results1 = []
        results2 = []
        phase_num = 0

        def worker():
            cond.acquire()
            ready.append(phase_num)
            woken = cond.wait()
            cond.release()
            results1.append((woken, phase_num))

            cond.acquire()
            ready.append(phase_num)
            woken = cond.wait()
            cond.release()
            results2.append((woken, phase_num))

        def wait_until(predicate):
            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
                if predicate():
                    break

        nthreads = 5
        with Bunch(worker, nthreads):
            # Let every worker settle into its first cond.wait() before
            # continuing (see issues #8799 and #30727).
            wait_until(lambda: len(ready) >= nthreads)
            ready.clear()
            self.assertEqual(results1, [])

            # Phase 1: wake 3 of the workers.
            count1 = 3
            cond.acquire()
            cond.notify(count1)
            wait_threads_blocked(count1)
            phase_num = 1
            cond.release()
            wait_until(lambda: len(results1) >= count1)
            self.assertEqual(results1, [(True, 1)] * count1)
            self.assertEqual(results2, [])

            # Wait until the woken workers are back in cond.wait().
            wait_until(lambda: len(ready) >= count1)

            # Phase 2: notify 5 threads, some in their first wait and
            # some in their second.
            cond.acquire()
            cond.notify(5)
            wait_threads_blocked(nthreads)
            phase_num = 2
            cond.release()
            wait_until(lambda: len(results1) + len(results2) >= nthreads + count1)
            count2 = nthreads - count1
            self.assertEqual(results1, [(True, 1)] * count1 + [(True, 2)] * count2)
            self.assertEqual(results2, [(True, 2)] * count1)

            # Make sure every worker has settled into its second cond.wait().
            wait_until(lambda: len(ready) >= nthreads)

            # Phase 3: wake everybody; all are in their second wait.
            cond.acquire()
            cond.notify_all()
            wait_threads_blocked(nthreads)
            phase_num = 3
            cond.release()
            wait_until(lambda: len(results2) >= nthreads)
            self.assertEqual(results1, [(True, 1)] * count1 + [(True, 2)] * count2)
            self.assertEqual(results2, [(True, 2)] * count1 + [(True, 3)] * count2)

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # Run it again to check that the internal state is still sound.
        self._check_notify(cond)

    def test_timeout(self):
        cond = self.condtype()
        timeout = 0.5
        measurements = []

        def waiter():
            cond.acquire()
            start = time.monotonic()
            notified = cond.wait(timeout)
            elapsed = time.monotonic() - start
            cond.release()
            measurements.append((elapsed, notified))

        nthreads = 5
        with Bunch(waiter, nthreads):
            pass
        self.assertEqual(len(measurements), nthreads)

        for elapsed, notified in measurements:
            self.assertTimeout(elapsed, timeout)
            # By the condition-variable protocol, wait() may in principle
            # return early without any notify (a spurious wakeup), which
            # makes the result hard to check in general. This
            # implementation has none in practice, so expect False.
            self.assertFalse(notified)

    def test_waitfor(self):
        cond = self.condtype()
        state = 0

        def waiter():
            with cond:
                ok = cond.wait_for(lambda: state == 4)
                self.assertTrue(ok)
                self.assertEqual(state, 4)

        with Bunch(waiter, 1):
            for _ in range(4):
                time.sleep(0.010)
                with cond:
                    state += 1
                    cond.notify()

    def test_waitfor_timeout(self):
        cond = self.condtype()
        state = 0
        success = []

        def waiter():
            with cond:
                start = time.monotonic()
                ok = cond.wait_for(lambda: state == 4, timeout=0.1)
                elapsed = time.monotonic() - start
                self.assertFalse(ok)
                self.assertTimeout(elapsed, 0.1)
                success.append(None)

        with Bunch(waiter, 1):
            # Only 3 increments, so state == 4 is never reached.
            for _ in range(3):
                time.sleep(0.010)
                with cond:
                    state += 1
                    cond.notify()

        self.assertEqual(len(success), 1)
746
747
class BaseSemaphoreTests(BaseTestCase):
    """
    Common tests for {bounded, unbounded} semaphore objects
    (subclasses provide `semtype`).
    """

    def test_constructor(self):
        for bad_value in (-1, -sys.maxsize):
            self.assertRaises(ValueError, self.semtype, value=bad_value)

    def test_acquire(self):
        for initial in (1, 2):
            sem = self.semtype(initial)
            for _ in range(initial):
                sem.acquire()
            for _ in range(initial):
                sem.release()

    def test_acquire_destroy(self):
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        sem_value = 7
        sem = self.semtype(sem_value)
        sem.acquire()

        sem_results = []
        results1 = []
        results2 = []
        phase_num = 0

        def func():
            sem_results.append(sem.acquire())
            results1.append(phase_num)

            sem_results.append(sem.acquire())
            results2.append(phase_num)

        def wait_count(count):
            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
                if len(results1) + len(results2) >= count:
                    break

        nthreads = 10
        with Bunch(func, nthreads):
            # Phase 0: the sem_value - 1 permits left go to the workers.
            count1 = sem_value - 1
            wait_count(count1)
            self.assertEqual(results1 + results2, [0] * count1)

            # Phase 1: hand out sem_value permits, one release() at a time.
            phase_num = 1
            count2 = sem_value
            for _ in range(count2):
                sem.release()
            wait_count(count1 + count2)
            self.assertEqual(sorted(results1 + results2),
                             [0] * count1 + [1] * count2)

            # Phase 2: release all but one of the permits still needed.
            phase_num = 2
            count3 = sem_value - 1
            for _ in range(count3):
                sem.release()
            wait_count(count1 + count2 + count3)
            self.assertEqual(sorted(results1 + results2),
                             [0] * count1 + [1] * count2 + [2] * count3)
            # No permit is left over.
            self.assertFalse(sem.acquire(False))

            # Final release, to let the last thread finish.
            count4 = 1
            sem.release()

        self.assertEqual(sem_results,
                         [True] * (count1 + count2 + count3 + count4))

    def test_multirelease(self):
        sem_value = 7
        sem = self.semtype(sem_value)
        sem.acquire()

        results1 = []
        results2 = []
        phase_num = 0

        def func():
            sem.acquire()
            results1.append(phase_num)

            sem.acquire()
            results2.append(phase_num)

        def wait_count(count):
            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
                if len(results1) + len(results2) >= count:
                    break

        nthreads = 10
        with Bunch(func, nthreads):
            # Phase 0
            count1 = sem_value - 1
            wait_count(count1)
            self.assertEqual(results1 + results2, [0] * count1)

            # Phase 1: a single release(n) call hands out n permits.
            phase_num = 1
            count2 = sem_value
            sem.release(count2)
            wait_count(count1 + count2)
            self.assertEqual(sorted(results1 + results2),
                             [0] * count1 + [1] * count2)

            # Phase 2
            phase_num = 2
            count3 = sem_value - 1
            sem.release(count3)
            wait_count(count1 + count2 + count3)
            self.assertEqual(sorted(results1 + results2),
                             [0] * count1 + [1] * count2 + [2] * count3)
            # No permit is left over.
            self.assertFalse(sem.acquire(False))

            # Final release, to let the last thread finish.
            sem.release()

    def test_try_acquire(self):
        sem = self.semtype(2)
        self.assertTrue(sem.acquire(False))
        self.assertTrue(sem.acquire(False))
        # Both permits are taken now.
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        sem = self.semtype(4)
        sem.acquire()
        outcomes = []

        def grab_twice():
            for _ in range(2):
                outcomes.append(sem.acquire(False))

        with Bunch(grab_twice, 5):
            pass
        # 3 permits were left for 10 non-blocking attempts. A thread switch
        # can happen between an acquire and its append, so the results are
        # not necessarily ordered: compare them sorted.
        self.assertEqual(sorted(outcomes), [False] * 7 + [True] * 3)

    def test_acquire_timeout(self):
        sem = self.semtype(2)
        # A timeout is rejected for a non-blocking acquire.
        self.assertRaises(ValueError, sem.acquire, False, timeout=1.0)
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertFalse(sem.acquire(timeout=0.005))
        sem.release()
        self.assertTrue(sem.acquire(timeout=0.005))
        start = time.monotonic()
        self.assertFalse(sem.acquire(timeout=0.5))
        self.assertTimeout(time.monotonic() - start, 0.5)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()

        def f():
            sem.acquire()
            sem.release()

        with Bunch(f, 1) as bunch:
            # The only permit is held here, so the thread must block.
            wait_threads_blocked(1)
            self.assertFalse(bunch.finished)

            sem.release()

    def test_with(self):
        sem = self.semtype(2)

        def nested(err=None):
            with sem:
                # One permit is still free...
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    # ...but none once the second is taken.
                    self.assertFalse(sem.acquire(False))
                    if err:
                        raise err

        # Leaving the with blocks normally returns both permits...
        nested()
        self.assertTrue(sem.acquire(False))
        sem.release()
        # ...and so does leaving them by an exception.
        self.assertRaises(TypeError, nested, TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
942
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests for unbounded semaphores.
    """

    def test_release_unacquired(self):
        # release() may raise the value past its initial value.
        sem = self.semtype(1)
        sem.release()
        # Thanks to the extra release, two acquires succeed.
        sem.acquire()
        sem.acquire()
        sem.release()

    def test_repr(self):
        sem = self.semtype(3)
        self.assertRegex(repr(sem), r"<\w+\.Semaphore at .*: value=3>")
        sem.acquire()
        self.assertRegex(repr(sem), r"<\w+\.Semaphore at .*: value=2>")
        for _ in range(2):
            sem.release()
        self.assertRegex(repr(sem), r"<\w+\.Semaphore at .*: value=4>")
964
965
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests for bounded semaphores.
    """

    def test_release_unacquired(self):
        # release() may not raise the value past its initial value.
        sem = self.semtype()
        self.assertRaises(ValueError, sem.release)
        sem.acquire()
        sem.release()
        self.assertRaises(ValueError, sem.release)

    def test_repr(self):
        sem = self.semtype(3)
        self.assertRegex(repr(sem), r"<\w+\.BoundedSemaphore at .*: value=3/3>")
        sem.acquire()
        self.assertRegex(repr(sem), r"<\w+\.BoundedSemaphore at .*: value=2/3>")
984
985
class BarrierTests(BaseTestCase):
    """
    Tests for Barrier objects.

    ``barriertype`` is not defined here; presumably a concrete subclass
    supplies the barrier class under test.
    """
    N = 5  # parties of self.barrier, and threads started by run_threads()
    defaultTimeout = 2.0  # default timeout (seconds) given to self.barrier

    def setUp(self):
        # Chain to the base fixture: overriding setUp() without super()
        # silently disables whatever setup BaseTestCase performs.
        super().setUp()
        self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)

    def tearDown(self):
        # Abort first so no thread can remain blocked on the barrier, then
        # always run the base class cleanup, even if abort() fails.
        try:
            self.barrier.abort()
        finally:
            super().tearDown()

    def run_threads(self, f):
        """
        Run ``f`` concurrently in ``self.N`` threads using Bunch.

        Callers inspect shared results right after this returns, so they
        rely on leaving the ``with Bunch(...)`` block to wait for the threads.
        """
        with Bunch(f, self.N):
            pass

    def multipass(self, results, n):
        """
        Pass ``self.barrier`` ``n`` times and check the threads stay in step.

        ``results`` is a pair of lists shared by all threads. Each thread
        appends to ``results[0]`` before the first wait of a round and to
        ``results[1]`` between the first and second wait. The length checks
        therefore fail if any thread gets ahead of the others.
        """
        m = self.barrier.parties
        self.assertEqual(m, self.N)
        for i in range(n):
            results[0].append(True)
            # Nobody can have reached the second phase of this round yet.
            self.assertEqual(len(results[1]), i * m)
            self.barrier.wait()
            results[1].append(True)
            # Every thread has completed the first phase of this round.
            self.assertEqual(len(results[0]), (i + 1) * m)
            self.barrier.wait()
        self.assertEqual(self.barrier.n_waiting, 0)
        self.assertFalse(self.barrier.broken)

    def test_constructor(self):
        # A barrier needs at least one party.
        self.assertRaises(ValueError, self.barriertype, parties=0)
        self.assertRaises(ValueError, self.barriertype, parties=-1)

    def test_barrier(self, passes=1):
        """
        Test that a barrier is passed in lockstep
        """
        results = [[],[]]
        def f():
            self.multipass(results, passes)
        self.run_threads(f)

    def test_barrier_10(self):
        """
        Test that a barrier works for 10 consecutive runs
        """
        self.test_barrier(10)

    def test_wait_return(self):
        """
        Test the return value from barrier.wait (checked via the sum of the
        values returned to the N threads).
        """
        results = []
        def f():
            r = self.barrier.wait()
            results.append(r)

        self.run_threads(f)
        self.assertEqual(sum(results), sum(range(self.N)))

    def test_action(self):
        """
        Test the 'action' callback: it has run exactly once by the time any
        thread returns from wait().
        """
        results = []
        def action():
            results.append(True)
        # A timeout makes a regression fail with BrokenBarrierError instead
        # of leaving the worker threads blocked forever.
        barrier = self.barriertype(self.N, action,
                                   timeout=self.defaultTimeout)
        def f():
            barrier.wait()
            self.assertEqual(len(results), 1)

        self.run_threads(f)

    def test_abort(self):
        """
        Test that an abort will put the barrier in a broken state
        """
        results1 = []
        results2 = []
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                # Exactly one thread ends up here; it breaks the barrier for
                # the N-1 threads waiting on the second wait().
                self.barrier.abort()

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertTrue(self.barrier.broken)

    def test_reset(self):
        """
        Test that a 'reset' on a barrier frees the waiting threads
        """
        results1 = []
        results2 = []
        results3 = []
        def f():
            i = self.barrier.wait()
            if i == self.N//2:
                # Wait until the other threads are all in the barrier.
                for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
                    if self.barrier.n_waiting >= (self.N - 1):
                        break
                self.barrier.reset()
            else:
                try:
                    self.barrier.wait()
                    results1.append(True)
                except threading.BrokenBarrierError:
                    results2.append(True)
            # Now, pass the barrier again
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)


    def test_abort_and_reset(self):
        """
        Test that a barrier can be reset after being broken.
        """
        results1 = []
        results2 = []
        results3 = []
        # Helper barrier used to sequence the reset; the timeout turns a
        # regression into a BrokenBarrierError instead of a hang.
        barrier2 = self.barriertype(self.N, timeout=self.defaultTimeout)
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()
            # Synchronize and reset the barrier.  Must synchronize first so
            # that everyone has left it when we reset, and after so that no
            # one enters it before the reset.
            if barrier2.wait() == self.N//2:
                self.barrier.reset()
            barrier2.wait()
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_timeout(self):
        """
        Test wait(timeout)
        """
        def f():
            i = self.barrier.wait()
            if i == self.N // 2:
                # One thread is late!
                time.sleep(self.defaultTimeout / 2)
            # Default timeout is 2.0, so this is shorter.
            self.assertRaises(threading.BrokenBarrierError,
                              self.barrier.wait, self.defaultTimeout / 4)
        self.run_threads(f)

    def test_default_timeout(self):
        """
        Test the barrier's default timeout
        """
        timeout = 0.100
        barrier = self.barriertype(2, timeout=timeout)
        def f():
            # Only one of the two parties ever arrives, so wait() can only
            # return by timing out.
            self.assertRaises(threading.BrokenBarrierError,
                              barrier.wait)

        start_time = time.monotonic()
        with Bunch(f, 1):
            pass
        dt = time.monotonic() - start_time
        # Timed waits and time.monotonic() may be slightly imprecise on some
        # platforms, so use the same tolerant helper as the other timeout
        # tests in this file rather than an exact lower bound.
        self.assertTimeout(dt, timeout)

    def test_single_thread(self):
        # A one-party barrier never blocks and can be reused immediately.
        b = self.barriertype(1)
        b.wait()
        b.wait()

    def test_repr(self):
        barrier = self.barriertype(3)
        timeout = support.LONG_TIMEOUT
        self.assertRegex(repr(barrier), r"<\w+\.Barrier at .*: waiters=0/3>")
        def f():
            barrier.wait(timeout)

        N = 2
        with Bunch(f, N):
            # Threads blocked on barrier.wait()
            for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
                if barrier.n_waiting >= N:
                    break
            self.assertRegex(repr(barrier),
                             r"<\w+\.Barrier at .*: waiters=2/3>")

            # Threads unblocked
            barrier.wait(timeout)

        self.assertRegex(repr(barrier),
                         r"<\w+\.Barrier at .*: waiters=0/3>")

        # Abort the barrier
        barrier.abort()
        self.assertRegex(repr(barrier),
                         r"<\w+\.Barrier at .*: broken>")
1211