• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Various tests for synchronization primitives.
3"""
4
5import os
6import gc
7import sys
8import time
9from _thread import start_new_thread, TIMEOUT_MAX
10import threading
11import unittest
12import weakref
13
14from test import support
15from test.support import threading_helper
16
17
# Skip decorator for tests that can only run where os.fork is available.
requires_fork = unittest.skipUnless(
    hasattr(os, 'fork'),
    "platform doesn't support fork (no _at_fork_reinit method)",
)
21
22
23def _wait():
24    # A crude wait/yield function not relying on synchronization primitives.
25    time.sleep(0.01)
26
class Bunch:
    """
    A group of raw threads that all run the same callable.
    """
    def __init__(self, f, n, wait_before_exit=False):
        """
        Start `n` threads right away, each of them calling `f` once.

        Thread idents are appended to `self.started` when a thread begins
        and to `self.finished` once `f` has returned or raised.  When
        `wait_before_exit` is true, each thread keeps spinning after `f`
        is done until do_finish() is called.
        """
        self.f = f
        self.n = n
        self.started = []
        self.finished = []
        self._can_exit = not wait_before_exit
        # Entered here; the matching __exit__ is in wait_for_finished().
        self.wait_thread = threading_helper.wait_threads_exit()
        self.wait_thread.__enter__()

        def worker():
            ident = threading.get_ident()
            self.started.append(ident)
            try:
                f()
            finally:
                self.finished.append(ident)
                # Hold the thread until it is allowed to terminate.
                while not self._can_exit:
                    _wait()

        try:
            for _ in range(n):
                start_new_thread(worker, ())
        except BaseException:
            # Let the threads that did start terminate, then propagate.
            self._can_exit = True
            raise

    def wait_for_started(self):
        """Poll until all `n` threads have recorded themselves as started."""
        while self.n > len(self.started):
            _wait()

    def wait_for_finished(self):
        """Poll until all `n` threads are done with `f`, then leave the
        wait_threads_exit() context entered by the constructor."""
        while self.n > len(self.finished):
            _wait()
        # Wait for threads exit
        self.wait_thread.__exit__(None, None, None)

    def do_finish(self):
        """Allow threads held back by `wait_before_exit` to terminate."""
        self._can_exit = True
74
75
class BaseTestCase(unittest.TestCase):
    """
    Common base class: records the threading state before each test and
    cleans it up (threads and child processes) afterwards.
    """
    def setUp(self):
        self._threads = threading_helper.threading_setup()

    def tearDown(self):
        threading_helper.threading_cleanup(*self._threads)
        support.reap_children()

    def assertTimeout(self, actual, expected):
        """
        Check that the measured duration `actual` is plausible for a
        requested timeout of `expected`.
        """
        # Waiting and/or time.monotonic() can be imprecise (especially
        # under Windows), so only 60% of the expected duration is required.
        lower_bound = expected * 0.6
        # Upper sanity bound: nothing insane happened.
        upper_bound = expected * 10.0
        self.assertGreaterEqual(actual, lower_bound)
        self.assertLess(actual, upper_bound)
91
92
class BaseLockTests(BaseTestCase):
    """
    Checks that apply to recursive and non-recursive locks alike.

    The lock factory `self.locktype` is not defined in this class; it has
    to be supplied by the class that actually runs these tests.
    """

    def test_constructor(self):
        # Creating a lock and dropping it again must not fail.
        fresh = self.locktype()
        del fresh

    def test_repr(self):
        lock = self.locktype()
        text = repr(lock)
        self.assertRegex(text, "<unlocked .* object (.*)?at .*>")
        del lock

    def test_locked_repr(self):
        lock = self.locktype()
        lock.acquire()
        text = repr(lock)
        self.assertRegex(text, "<locked .* object (.*)?at .*>")
        del lock

    def test_acquire_destroy(self):
        # A lock may be dropped while it is still held.
        held = self.locktype()
        held.acquire()
        del held

    def test_acquire_release(self):
        lock = self.locktype()
        lock.acquire()
        lock.release()
        del lock

    def test_try_acquire(self):
        # A non-blocking acquire of a free lock succeeds.
        lock = self.locktype()
        acquired = lock.acquire(False)
        self.assertTrue(acquired)
        lock.release()

    def test_try_acquire_contended(self):
        # A non-blocking acquire from another thread fails while the lock
        # is held by this one.
        lock = self.locktype()
        lock.acquire()
        outcomes = []
        def try_lock():
            outcomes.append(lock.acquire(False))
        Bunch(try_lock, 1).wait_for_finished()
        self.assertFalse(outcomes[0])
        lock.release()

    def test_acquire_contended(self):
        # Blocking acquires from other threads only complete once the
        # lock has been released here.
        lock = self.locktype()
        lock.acquire()
        nthreads = 5
        def grab_and_drop():
            lock.acquire()
            lock.release()

        bunch = Bunch(grab_and_drop, nthreads)
        bunch.wait_for_started()
        _wait()
        self.assertEqual(len(bunch.finished), 0)
        lock.release()
        bunch.wait_for_finished()
        self.assertEqual(len(bunch.finished), nthreads)

    def test_with(self):
        lock = self.locktype()
        def grab_and_drop():
            lock.acquire()
            lock.release()
        def enter_lock(err=None):
            with lock:
                if err is not None:
                    raise err
        enter_lock()
        # The lock must be free again after a normal exit
        Bunch(grab_and_drop, 1).wait_for_finished()
        with self.assertRaises(TypeError):
            enter_lock(TypeError)
        # ... and also after leaving the block through an exception
        Bunch(grab_and_drop, 1).wait_for_finished()

    def test_thread_leak(self):
        # Using the lock from a foreign (non-threading) thread must not
        # leak a Thread instance.
        lock = self.locktype()
        def grab_and_drop():
            lock.acquire()
            lock.release()
        baseline = len(threading.enumerate())
        # Many threads are run in the hope that existing thread ids won't
        # be recycled.
        Bunch(grab_and_drop, 15).wait_for_finished()
        if len(threading.enumerate()) != baseline:
            # There is a small window during which a Thread instance's
            # target function has finished running, but the Thread is still
            # alive and registered.  Wait a bit more to avoid spurious
            # failures (seen on a buildbot).
            time.sleep(0.4)
            self.assertEqual(baseline, len(threading.enumerate()))

    def test_timeout(self):
        lock = self.locktype()
        # A timeout cannot be combined with a non-blocking acquire
        with self.assertRaises(ValueError):
            lock.acquire(False, 1)
        # Out-of-range timeout values
        with self.assertRaises(ValueError):
            lock.acquire(timeout=-100)
        with self.assertRaises(OverflowError):
            lock.acquire(timeout=1e100)
        with self.assertRaises(OverflowError):
            lock.acquire(timeout=TIMEOUT_MAX + 1)
        # TIMEOUT_MAX itself is accepted
        lock.acquire(timeout=TIMEOUT_MAX)
        lock.release()
        began = time.monotonic()
        self.assertTrue(lock.acquire(timeout=5))
        elapsed = time.monotonic() - began
        # Sanity check: an uncontended acquire must not wait for the timeout.
        self.assertLess(elapsed, 5)
        results = []
        def timed_acquire():
            began = time.monotonic()
            results.append(lock.acquire(timeout=0.5))
            results.append(time.monotonic() - began)
        Bunch(timed_acquire, 1).wait_for_finished()
        self.assertFalse(results[0])
        self.assertTimeout(results[1], 0.5)

    def test_weakref_exists(self):
        lock = self.locktype()
        wr = weakref.ref(lock)
        self.assertIsNotNone(wr())

    def test_weakref_deleted(self):
        lock = self.locktype()
        wr = weakref.ref(lock)
        del lock
        gc.collect()  # For PyPy or other GCs.
        self.assertIsNone(wr())
227
228
class LockTests(BaseLockTests):
    """
    Tests for non-recursive, weak locks, i.e. locks that may be acquired
    in one thread and released in another.
    """
    def test_reacquire(self):
        # A second acquire() blocks until the lock has been released.
        lock = self.locktype()
        steps = []

        def double_acquire():
            lock.acquire()
            steps.append(None)
            lock.acquire()
            steps.append(None)

        with threading_helper.wait_threads_exit():
            start_new_thread(double_acquire, ())
            while not steps:
                _wait()
            _wait()
            # The thread must now be stuck in its second acquire().
            self.assertEqual(len(steps), 1)
            lock.release()
            while len(steps) == 1:
                _wait()
            self.assertEqual(len(steps), 2)

    def test_different_thread(self):
        # A thread other than the acquiring one may release the lock.
        lock = self.locktype()
        lock.acquire()
        def release_it():
            lock.release()
        bunch = Bunch(release_it, 1)
        bunch.wait_for_finished()
        lock.acquire()
        lock.release()

    def test_state_after_timeout(self):
        # Issue #11618: the lock must be in a proper state after a
        # (non-zero) timeout expired.
        lock = self.locktype()
        lock.acquire()
        timed_out = lock.acquire(timeout=0.01)
        self.assertFalse(timed_out)
        lock.release()
        self.assertFalse(lock.locked())
        self.assertTrue(lock.acquire(blocking=False))

    @requires_fork
    def test_at_fork_reinit(self):
        def check_usable(lock):
            # The lock must still work normally after _at_fork_reinit().
            lock.acquire()
            lock.release()

        # Reinitializing an unlocked lock
        unlocked = self.locktype()
        unlocked._at_fork_reinit()
        check_usable(unlocked)

        # Reinitializing a held lock resets it to the unlocked state
        held = self.locktype()
        held.acquire()
        held._at_fork_reinit()
        check_usable(held)
295
296
class RLockTests(BaseLockTests):
    """
    Tests specific to recursive locks.
    """
    def test_reacquire(self):
        # The owning thread may nest acquire()/release() pairs.
        rlock = self.locktype()
        rlock.acquire()
        rlock.acquire()
        rlock.release()
        rlock.acquire()
        rlock.release()
        rlock.release()

    def test_release_unacquired(self):
        # release() fails on a lock that is not held, both on a fresh lock
        # and after all nested acquisitions have been undone.
        rlock = self.locktype()
        with self.assertRaises(RuntimeError):
            rlock.release()
        rlock.acquire()
        rlock.acquire()
        rlock.release()
        rlock.acquire()
        rlock.release()
        rlock.release()
        with self.assertRaises(RuntimeError):
            rlock.release()

    def test_release_save_unacquired(self):
        # _release_save() fails on a lock that is not held, both on a fresh
        # lock and after all nested acquisitions have been undone.
        rlock = self.locktype()
        with self.assertRaises(RuntimeError):
            rlock._release_save()
        rlock.acquire()
        rlock.acquire()
        rlock.release()
        rlock.acquire()
        rlock.release()
        rlock.release()
        with self.assertRaises(RuntimeError):
            rlock._release_save()

    def test_different_thread(self):
        # Only the acquiring thread may release the lock.
        rlock = self.locktype()
        def take_lock():
            rlock.acquire()
        # The thread is kept alive (wait_before_exit) until do_finish().
        bunch = Bunch(take_lock, 1, True)
        try:
            with self.assertRaises(RuntimeError):
                rlock.release()
        finally:
            bunch.do_finish()
        bunch.wait_for_finished()

    def test__is_owned(self):
        rlock = self.locktype()
        self.assertFalse(rlock._is_owned())
        rlock.acquire()
        self.assertTrue(rlock._is_owned())
        rlock.acquire()
        self.assertTrue(rlock._is_owned())
        # Seen from another thread, the lock is not owned.
        seen_from_thread = []
        def probe():
            seen_from_thread.append(rlock._is_owned())
        Bunch(probe, 1).wait_for_finished()
        self.assertFalse(seen_from_thread[0])
        rlock.release()
        self.assertTrue(rlock._is_owned())
        rlock.release()
        self.assertFalse(rlock._is_owned())
362
363
class EventTests(BaseTestCase):
    """
    Tests for Event objects.

    The event factory `self.eventtype` is not defined in this class; it
    has to be supplied by the class that actually runs these tests.
    """

    def test_is_set(self):
        # set() and clear() can each be repeated without changing the state.
        evt = self.eventtype()
        self.assertFalse(evt.is_set())
        for _ in range(2):
            evt.set()
            self.assertTrue(evt.is_set())
        for _ in range(2):
            evt.clear()
            self.assertFalse(evt.is_set())

    def _check_notify(self, evt):
        # Every waiting thread must be woken up by set().
        nthreads = 5
        first_wait = []
        second_wait = []
        def waiter():
            first_wait.append(evt.wait())
            second_wait.append(evt.wait())
        bunch = Bunch(waiter, nthreads)
        bunch.wait_for_started()
        _wait()
        # Nobody may get past the first wait() before the event is set.
        self.assertEqual(len(first_wait), 0)
        evt.set()
        bunch.wait_for_finished()
        self.assertEqual(first_wait, [True] * nthreads)
        self.assertEqual(second_wait, [True] * nthreads)

    def test_notify(self):
        evt = self.eventtype()
        self._check_notify(evt)
        # Once more, after an explicit set()/clear() cycle
        evt.set()
        evt.clear()
        self._check_notify(evt)

    def test_timeout(self):
        evt = self.eventtype()
        immediate = []
        timed = []
        nthreads = 5
        def waiter():
            immediate.append(evt.wait(0.0))
            began = time.monotonic()
            flag = evt.wait(0.5)
            elapsed = time.monotonic() - began
            timed.append((flag, elapsed))
        Bunch(waiter, nthreads).wait_for_finished()
        self.assertEqual(immediate, [False] * nthreads)
        for flag, elapsed in timed:
            self.assertFalse(flag)
            self.assertTimeout(elapsed, 0.5)
        # Same again, this time with the event set
        immediate = []
        timed = []
        evt.set()
        Bunch(waiter, nthreads).wait_for_finished()
        self.assertEqual(immediate, [True] * nthreads)
        for flag, _ in timed:
            self.assertTrue(flag)

    def test_set_and_clear(self):
        # Issue #13502: wait() must return true even when the event is
        # cleared again before the waiting thread is woken up.
        evt = self.eventtype()
        outcomes = []
        timeout = 0.250
        nthreads = 5
        def waiter():
            outcomes.append(evt.wait(timeout * 4))
        bunch = Bunch(waiter, nthreads)
        bunch.wait_for_started()
        time.sleep(timeout)
        evt.set()
        evt.clear()
        bunch.wait_for_finished()
        self.assertEqual(outcomes, [True] * nthreads)

    @requires_fork
    def test_at_fork_reinit(self):
        # The condition must still be using a Lock after the reset: while
        # it is held, a non-blocking acquire has to fail.
        evt = self.eventtype()
        with evt._cond:
            self.assertFalse(evt._cond.acquire(False))
        evt._at_fork_reinit()
        with evt._cond:
            self.assertFalse(evt._cond.acquire(False))
457
458
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables.

    The condition factory `self.condtype` is not defined in this class; it
    has to be supplied by the class that actually runs these tests.
    """

    def test_acquire(self):
        cond = self.condtype()
        # By default we have an RLock: the condition can be acquired
        # multiple times.
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()
        # With an explicit plain Lock, the condition and the lock share
        # their locked state.
        plain_lock = threading.Lock()
        cond = self.condtype(plain_lock)
        cond.acquire()
        self.assertFalse(plain_lock.acquire(False))
        cond.release()
        self.assertTrue(plain_lock.acquire(False))
        self.assertFalse(cond.acquire(False))
        plain_lock.release()
        with cond:
            self.assertFalse(plain_lock.acquire(False))

    def test_unacquired_wait(self):
        # wait() requires the condition to be held.
        cond = self.condtype()
        with self.assertRaises(RuntimeError):
            cond.wait()

    def test_unacquired_notify(self):
        # notify() requires the condition to be held.
        cond = self.condtype()
        with self.assertRaises(RuntimeError):
            cond.notify()

    def _check_notify(self, cond):
        # Note that this test is sensitive to timing.  If the worker threads
        # don't execute in a timely fashion, the main thread may think they
        # are further along than they are.  The main thread therefore issues
        # _wait() statements to try to make sure that it doesn't race ahead
        # of the workers.
        # Secondly, this test assumes that condition variables are not
        # subject to spurious wakeups.  The absence of spurious wakeups is
        # an implementation detail of Condition Variables in current
        # CPython, but in general, not a guaranteed property of condition
        # variables as a programming construct.  In particular, it is
        # possible that this can no longer be conveniently guaranteed should
        # their implementation ever change.
        N = 5
        ready = []
        first_wait = []
        second_wait = []
        phase = 0
        def worker():
            cond.acquire()
            ready.append(phase)
            woken = cond.wait()
            cond.release()
            first_wait.append((woken, phase))
            cond.acquire()
            ready.append(phase)
            woken = cond.wait()
            cond.release()
            second_wait.append((woken, phase))
        bunch = Bunch(worker, N)
        bunch.wait_for_started()
        # first wait, to ensure all workers settle into cond.wait() before
        # we continue. See issues #8799 and #30727.
        while len(ready) < N:
            _wait()
        ready.clear()
        self.assertEqual(first_wait, [])
        # Wake up 3 threads at first
        cond.acquire()
        cond.notify(3)
        _wait()
        phase = 1
        cond.release()
        while len(first_wait) < 3:
            _wait()
        self.assertEqual(first_wait, [(True, 1)] * 3)
        self.assertEqual(second_wait, [])
        # make sure all awakened workers settle into cond.wait()
        while len(ready) < 3:
            _wait()
        # Wake up 5 threads: they might be in their first or second wait
        cond.acquire()
        cond.notify(5)
        _wait()
        phase = 2
        cond.release()
        while len(first_wait) + len(second_wait) < 8:
            _wait()
        self.assertEqual(first_wait, [(True, 1)] * 3 + [(True, 2)] * 2)
        self.assertEqual(second_wait, [(True, 2)] * 3)
        # make sure all workers settle into cond.wait()
        while len(ready) < N:
            _wait()
        # Wake up all threads: they are all in their second wait
        cond.acquire()
        cond.notify_all()
        _wait()
        phase = 3
        cond.release()
        while len(second_wait) < N:
            _wait()
        self.assertEqual(first_wait, [(True, 1)] * 3 + [(True, 2)] * 2)
        self.assertEqual(second_wait, [(True, 2)] * 3 + [(True, 3)] * 2)
        bunch.wait_for_finished()

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # Run it once more to check the internal state is still ok.
        self._check_notify(cond)

    def test_timeout(self):
        cond = self.condtype()
        measurements = []
        nthreads = 5
        def timed_wait():
            cond.acquire()
            began = time.monotonic()
            woken = cond.wait(0.5)
            ended = time.monotonic()
            cond.release()
            measurements.append((ended - began, woken))
        Bunch(timed_wait, nthreads).wait_for_finished()
        self.assertEqual(len(measurements), nthreads)
        for elapsed, woken in measurements:
            self.assertTimeout(elapsed, 0.5)
            # Note that conceptually (that's the condition variable protocol)
            # a wait() may succeed even if no one notifies us and before any
            # timeout occurs.  Spurious wakeups can occur.
            # This makes it hard to verify the result value.
            # In practice, this implementation has no spurious wakeups.
            self.assertFalse(woken)

    def test_waitfor(self):
        cond = self.condtype()
        state = 0
        def waiter():
            with cond:
                reached = cond.wait_for(lambda: state == 4)
                self.assertTrue(reached)
                self.assertEqual(state, 4)
        bunch = Bunch(waiter, 1)
        bunch.wait_for_started()
        # Step the state up to 4, notifying the waiter after each step.
        for _ in range(4):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        bunch.wait_for_finished()

    def test_waitfor_timeout(self):
        cond = self.condtype()
        state = 0
        success = []
        def waiter():
            with cond:
                began = time.monotonic()
                reached = cond.wait_for(lambda: state == 4, timeout=0.1)
                elapsed = time.monotonic() - began
                self.assertFalse(reached)
                self.assertTimeout(elapsed, 0.1)
                success.append(None)
        bunch = Bunch(waiter, 1)
        bunch.wait_for_started()
        # Only increment 3 times, so state == 4 is never reached.
        for _ in range(3):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        bunch.wait_for_finished()
        self.assertEqual(len(success), 1)
632
633
class BaseSemaphoreTests(BaseTestCase):
    """
    Tests shared by bounded and unbounded semaphore objects.

    The semaphore factory `self.semtype` is not defined in this class; it
    has to be supplied by the class that actually runs these tests.
    """

    def test_constructor(self):
        # Negative initial values are rejected.
        with self.assertRaises(ValueError):
            self.semtype(value=-1)
        with self.assertRaises(ValueError):
            self.semtype(value=-sys.maxsize)

    def test_acquire(self):
        single = self.semtype(1)
        single.acquire()
        single.release()
        pair = self.semtype(2)
        pair.acquire()
        pair.acquire()
        pair.release()
        pair.release()

    def test_acquire_destroy(self):
        # A semaphore may be dropped while it is acquired.
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        sem = self.semtype(7)
        sem.acquire()
        N = 10
        acquire_results = []
        first_pass = []
        second_pass = []
        phase = 0
        def worker():
            acquire_results.append(sem.acquire())
            first_pass.append(phase)
            acquire_results.append(sem.acquire())
            second_pass.append(phase)
        bunch = Bunch(worker, N)
        bunch.wait_for_started()
        # 7 - 1 = 6 acquisitions can succeed before anything is released.
        while len(first_pass) + len(second_pass) < 6:
            _wait()
        self.assertEqual(first_pass + second_pass, [0] * 6)
        phase = 1
        for _ in range(7):
            sem.release()
        while len(first_pass) + len(second_pass) < 13:
            _wait()
        self.assertEqual(sorted(first_pass + second_pass),
                         [0] * 6 + [1] * 7)
        phase = 2
        for _ in range(6):
            sem.release()
        while len(first_pass) + len(second_pass) < 19:
            _wait()
        self.assertEqual(sorted(first_pass + second_pass),
                         [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # One last release lets the remaining thread finish
        sem.release()
        bunch.wait_for_finished()
        self.assertEqual(acquire_results, [True] * (6 + 7 + 6 + 1))

    def test_multirelease(self):
        # Same scenario as test_acquire_contended, but releasing several
        # units with a single release(n) call.
        sem = self.semtype(7)
        sem.acquire()
        first_pass = []
        second_pass = []
        phase = 0
        def worker():
            sem.acquire()
            first_pass.append(phase)
            sem.acquire()
            second_pass.append(phase)
        bunch = Bunch(worker, 10)
        bunch.wait_for_started()
        while len(first_pass) + len(second_pass) < 6:
            _wait()
        self.assertEqual(first_pass + second_pass, [0] * 6)
        phase = 1
        sem.release(7)
        while len(first_pass) + len(second_pass) < 13:
            _wait()
        self.assertEqual(sorted(first_pass + second_pass),
                         [0] * 6 + [1] * 7)
        phase = 2
        sem.release(6)
        while len(first_pass) + len(second_pass) < 19:
            _wait()
        self.assertEqual(sorted(first_pass + second_pass),
                         [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # One last release lets the remaining thread finish
        sem.release()
        bunch.wait_for_finished()

    def test_try_acquire(self):
        sem = self.semtype(2)
        for _ in range(2):
            self.assertTrue(sem.acquire(False))
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        sem = self.semtype(4)
        sem.acquire()
        outcomes = []
        def try_twice():
            outcomes.append(sem.acquire(False))
            outcomes.append(sem.acquire(False))
        Bunch(try_twice, 5).wait_for_finished()
        # There can be a thread switch between acquiring the semaphore and
        # appending the result, therefore the outcomes will not necessarily
        # be ordered.
        self.assertEqual(sorted(outcomes), [False] * 7 + [True] * 3)

    def test_acquire_timeout(self):
        sem = self.semtype(2)
        # A timeout cannot be combined with a non-blocking acquire.
        with self.assertRaises(ValueError):
            sem.acquire(False, timeout=1.0)
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertFalse(sem.acquire(timeout=0.005))
        sem.release()
        self.assertTrue(sem.acquire(timeout=0.005))
        began = time.monotonic()
        self.assertFalse(sem.acquire(timeout=0.5))
        elapsed = time.monotonic() - began
        self.assertTimeout(elapsed, 0.5)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()
        def grab_and_drop():
            sem.acquire()
            sem.release()
        bunch = Bunch(grab_and_drop, 1)
        bunch.wait_for_started()
        _wait()
        self.assertFalse(bunch.finished)
        sem.release()
        bunch.wait_for_finished()

    def test_with(self):
        sem = self.semtype(2)
        def nested(err=None):
            with sem:
                # One unit is still available inside the outer block
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    # ... and none inside the inner block
                    self.assertFalse(sem.acquire(False))
                    if err:
                        raise err
        nested()
        self.assertTrue(sem.acquire(False))
        sem.release()
        with self.assertRaises(TypeError):
            nested(TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
791
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests specific to unbounded semaphores.
    """

    def test_release_unacquired(self):
        # Releasing without a prior acquire is allowed and increments the
        # semaphore's value, so two acquires succeed afterwards.
        unbounded = self.semtype(1)
        unbounded.release()
        unbounded.acquire()
        unbounded.acquire()
        unbounded.release()
804
805
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests specific to bounded semaphores.
    """

    def test_release_unacquired(self):
        # The value cannot go past the initial value
        bounded = self.semtype()
        with self.assertRaises(ValueError):
            bounded.release()
        bounded.acquire()
        bounded.release()
        with self.assertRaises(ValueError):
            bounded.release()
818
819
class BarrierTests(BaseTestCase):
    """
    Tests for Barrier objects.

    The barrier factory `self.barriertype` is not defined in this class; it
    has to be supplied by the class that actually runs these tests.  Each
    test gets a fresh `self.barrier` for `N` parties using `defaultTimeout`.
    """
    N = 5
    defaultTimeout = 2.0

    def setUp(self):
        # Run the inherited setUp as well, so the inherited thread
        # bookkeeping also applies to the barrier tests.
        super().setUp()
        self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)

    def tearDown(self):
        # Break the barrier first, then run the inherited cleanup.
        self.barrier.abort()
        super().tearDown()

    def run_threads(self, f):
        """
        Run `f` in N parties: N-1 Bunch threads plus the calling thread.
        """
        b = Bunch(f, self.N-1)
        f()
        b.wait_for_finished()

    def multipass(self, results, n):
        """
        Pass the barrier 2*n times, checking via the two lists in `results`
        that all parties advance in lockstep.
        """
        m = self.barrier.parties
        self.assertEqual(m, self.N)
        for i in range(n):
            results[0].append(True)
            self.assertEqual(len(results[1]), i * m)
            self.barrier.wait()
            results[1].append(True)
            self.assertEqual(len(results[0]), (i + 1) * m)
            self.barrier.wait()
        self.assertEqual(self.barrier.n_waiting, 0)
        self.assertFalse(self.barrier.broken)

    def test_barrier(self, passes=1):
        """
        Test that a barrier is passed in lockstep
        """
        results = [[], []]
        def f():
            self.multipass(results, passes)
        self.run_threads(f)

    def test_barrier_10(self):
        """
        Test that a barrier works for 10 consecutive runs
        """
        # No `return` here: a test method should not return a value.
        self.test_barrier(10)

    def test_wait_return(self):
        """
        test the return value from barrier.wait
        """
        results = []
        def f():
            r = self.barrier.wait()
            results.append(r)

        self.run_threads(f)
        # The N return values together must be 0..N-1.
        self.assertEqual(sum(results), sum(range(self.N)))

    def test_action(self):
        """
        Test the 'action' callback
        """
        results = []
        def action():
            results.append(True)
        barrier = self.barriertype(self.N, action)
        def f():
            barrier.wait()
            # The action must have run exactly once for all N parties.
            self.assertEqual(len(results), 1)

        self.run_threads(f)

    def test_abort(self):
        """
        Test that an abort will put the barrier in a broken state
        """
        results1 = []
        results2 = []
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    # One party bails out instead of waiting again.
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertTrue(self.barrier.broken)

    def test_reset(self):
        """
        Test that a 'reset' on a barrier frees the waiting threads
        """
        results1 = []
        results2 = []
        results3 = []
        def f():
            i = self.barrier.wait()
            if i == self.N//2:
                # Wait until the other threads are all in the barrier.
                while self.barrier.n_waiting < self.N-1:
                    time.sleep(0.001)
                self.barrier.reset()
            else:
                try:
                    self.barrier.wait()
                    results1.append(True)
                except threading.BrokenBarrierError:
                    results2.append(True)
            # Now, pass the barrier again
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_abort_and_reset(self):
        """
        Test that a barrier can be reset after being broken.
        """
        results1 = []
        results2 = []
        results3 = []
        # Second barrier, used only to synchronize around the reset.
        barrier2 = self.barriertype(self.N)
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()
            # Synchronize and reset the barrier.  Must synchronize first so
            # that everyone has left it when we reset, and after so that no
            # one enters it before the reset.
            if barrier2.wait() == self.N//2:
                self.barrier.reset()
            barrier2.wait()
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_timeout(self):
        """
        Test wait(timeout)
        """
        def f():
            i = self.barrier.wait()
            if i == self.N // 2:
                # One thread is late!
                time.sleep(1.0)
            # Default timeout is 2.0, so this is shorter.
            self.assertRaises(threading.BrokenBarrierError,
                              self.barrier.wait, 0.5)
        self.run_threads(f)

    def test_default_timeout(self):
        """
        Test the barrier's default timeout
        """
        # create a barrier with a low default timeout
        barrier = self.barriertype(self.N, timeout=0.3)
        def f():
            i = barrier.wait()
            if i == self.N // 2:
                # One thread is later than the default timeout of 0.3s.
                time.sleep(1.0)
            self.assertRaises(threading.BrokenBarrierError, barrier.wait)
        self.run_threads(f)

    def test_single_thread(self):
        """
        Test that a barrier for a single party can be passed repeatedly
        """
        b = self.barriertype(1)
        b.wait()
        b.wait()
1011