• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Synchronization primitives."""
2
3__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
4           'BoundedSemaphore', 'Barrier')
5
6import collections
7import enum
8
9from . import exceptions
10from . import mixins
11
class _ContextManagerMixin:
    """Give a lock-like class ``async with`` support.

    The class mixing this in must supply a coroutine ``acquire()`` and a
    plain (synchronous) ``release()`` method.
    """

    async def __aenter__(self):
        # Nothing useful can be bound by ``async with obj as x`` for a lock,
        # so the implicit None return is deliberate.
        await self.acquire()

    async def __aexit__(self, exc_type, exc, tb):
        # Release unconditionally; the implicit None return means any
        # exception raised inside the block keeps propagating.
        self.release()
21
22
class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Primitive lock objects.

    A primitive lock is a synchronization primitive that is not owned
    by a particular task when locked.  A primitive lock is in one
    of two states, 'locked' or 'unlocked'.

    It is created in the unlocked state.  It has two basic methods,
    acquire() and release().  When the state is unlocked, acquire()
    changes the state to locked and returns immediately.  When the
    state is locked, acquire() blocks until a call to release() in
    another task changes it to unlocked, then the acquire() call
    resets it to locked and returns.  The release() method should only
    be called in the locked state; it changes the state to unlocked
    and returns immediately.  If an attempt is made to release an
    unlocked lock, a RuntimeError will be raised.

    When more than one task is blocked in acquire() waiting for
    the state to turn to unlocked, only one task proceeds when a
    release() call resets the state to unlocked; successive release()
    calls will unblock tasks in FIFO order.

    Locks also support the asynchronous context management protocol.
    'async with lock' statement should be used.

    Usage:

        lock = Lock()
        ...
        await lock.acquire()
        try:
            ...
        finally:
            lock.release()

    Context manager usage:

        lock = Lock()
        ...
        async with lock:
            ...

    Lock objects can be tested for locking state:

        if not lock.locked():
            await lock.acquire()
        else:
            # lock is acquired
            ...

    """

    def __init__(self):
        # FIFO of futures, one per task blocked in acquire(); the deque is
        # only created on first contention.
        self._waiters = None
        self._locked = False

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self._locked else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def locked(self):
        """Return True if lock is acquired."""
        return self._locked

    async def acquire(self):
        """Acquire a lock.

        This method blocks until the lock is unlocked, then sets it to
        locked and returns True.
        """
        # Implement fair scheduling, where a task always waits its turn.
        # Jumping the queue if all waiters are cancelled is an optimization.
        if (not self._locked and (self._waiters is None or
                all(w.cancelled() for w in self._waiters))):
            self._locked = True
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            # Currently the only exception designed to be able to occur here.

            # Ensure the lock invariant: If lock is not claimed (or about
            # to be claimed by us) and there is a Task in waiters,
            # ensure that the Task at the head will run.
            if not self._locked:
                self._wake_up_first()
            raise

        # assert self._locked is False
        self._locked = True
        return True

    def release(self):
        """Release a lock.

        When the lock is locked, reset it to unlocked, and return.
        If any other tasks are blocked waiting for the lock to become
        unlocked, allow exactly one of them to proceed.

        When invoked on an unlocked lock, a RuntimeError is raised.

        There is no return value.
        """
        if self._locked:
            self._locked = False
            self._wake_up_first()
        else:
            raise RuntimeError('Lock is not acquired.')

    def _wake_up_first(self):
        """Ensure that the first waiter will wake up."""
        # Covers both "never contended" (None) and "queue drained" (empty).
        if not self._waiters:
            return
        # The deque is non-empty here, so its head always exists.
        fut = self._waiters[0]

        # .done() means the waiter was already woken or was cancelled;
        # setting a result on a done future would raise.
        if not fut.done():
            fut.set_result(True)
156
157
class Event(mixins._LoopBoundMixin):
    """Asynchronous counterpart of threading.Event.

    An Event wraps a boolean flag that starts out false.  set() turns the
    flag on and wakes every task blocked in wait(); clear() turns it off
    again.  While the flag is on, wait() returns without blocking.
    """

    def __init__(self):
        # One future per task currently suspended in wait().
        self._waiters = collections.deque()
        self._value = False

    def __repr__(self):
        base = super().__repr__()
        state = 'set' if self._value else 'unset'
        if self._waiters:
            state += f', waiters:{len(self._waiters)}'
        return f'<{base[1:-1]} [{state}]>'

    def is_set(self):
        """Report whether the internal flag is currently true."""
        return self._value

    def set(self):
        """Turn the flag on and wake all tasks blocked in wait().

        Calling set() on an event that is already set does nothing.
        Afterwards wait() returns immediately until clear() is called.
        """
        if self._value:
            return
        self._value = True
        # Futures that are already done (e.g. cancelled) must not be
        # resolved a second time.
        for waiter in self._waiters:
            if not waiter.done():
                waiter.set_result(True)

    def clear(self):
        """Turn the flag off so that subsequent wait() calls block again."""
        self._value = False

    async def wait(self):
        """Suspend until the flag is true, then return True.

        If the flag is already set on entry, return True at once.
        """
        if not self._value:
            waiter = self._get_loop().create_future()
            self._waiters.append(waiter)
            try:
                await waiter
            finally:
                # Drop our future whether we were woken or cancelled.
                self._waiters.remove(waiter)
        return True
217
218
class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Condition.

    This class implements condition variable objects. A condition variable
    allows one or more tasks to wait until they are notified by another
    task.

    If the optional *lock* argument is given and not None, it is used as
    the underlying lock (only its locked(), acquire() and release()
    methods are used).  Otherwise a new Lock object is created and used
    as the underlying lock.
    """

    def __init__(self, lock=None):
        if lock is None:
            lock = Lock()

        self._lock = lock
        # Export the lock's locked(), acquire() and release() methods.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release

        # FIFO of futures, one per task currently suspended in wait().
        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    async def wait(self):
        """Wait until notified.

        If the calling task has not acquired the lock when this
        method is called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks
        until it is awakened by a notify() or notify_all() call for
        the same condition variable in another task.  Once
        awakened, it re-acquires the lock and returns True.

        If the wait is cancelled, the lock is still re-acquired before
        CancelledError propagates to the caller.

        This method may return spuriously,
        which is why the caller should always
        re-check the state and be prepared to wait() again.
        """
        if not self.locked():
            raise RuntimeError('cannot wait on un-acquired lock')

        fut = self._get_loop().create_future()
        self.release()
        try:
            try:
                self._waiters.append(fut)
                try:
                    await fut
                    # The outer finally re-acquires the lock before this
                    # return actually completes.
                    return True
                finally:
                    self._waiters.remove(fut)

            finally:
                # Must re-acquire lock even if wait is cancelled.
                # We only catch CancelledError here, since we don't want any
                # other (fatal) errors with the future to cause us to spin.
                err = None
                while True:
                    try:
                        await self.acquire()
                        break
                    except exceptions.CancelledError as e:
                        err = e

                if err is not None:
                    try:
                        raise err  # Re-raise most recent exception instance.
                    finally:
                        err = None  # Break reference cycles.
        except BaseException:
            # Any error raised out of here _may_ have occurred after this Task
            # believed to have been successfully notified.
            # Make sure to notify another Task instead.  This may result
            # in a "spurious wakeup", which is allowed as part of the
            # Condition Variable protocol.
            self._notify(1)
            raise

    async def wait_for(self, predicate):
        """Wait until a predicate becomes true.

        The predicate should be a callable whose result will be
        interpreted as a boolean value.  The method will repeatedly
        wait() until it evaluates to true.  The final predicate value is
        the return value.  The predicate is checked once before the first
        wait(), so no wait happens if it is already true.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n=1):
        """By default, wake up one task waiting on this condition, if any.
        If the calling task has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method wakes up n of the tasks waiting for the condition
        variable; if fewer than n are waiting, they are all awoken.

        Note: an awakened task does not actually return from its
        wait() call until it can reacquire the lock. Since notify() does
        not release the lock, its caller should.
        """
        if not self.locked():
            raise RuntimeError('cannot notify on un-acquired lock')
        self._notify(n)

    def _notify(self, n):
        """Wake up to *n* waiters, in FIFO order, without a lock check.

        Futures that are already done (previously notified or cancelled)
        are skipped and do not count towards *n*.
        """
        idx = 0
        for fut in self._waiters:
            if idx >= n:
                break

            if not fut.done():
                idx += 1
                fut.set_result(False)

    def notify_all(self):
        """Wake up all tasks waiting on this condition. This method acts
        like notify(), but wakes up all waiting tasks instead of one. If the
        calling task has not acquired the lock when this method is called,
        a RuntimeError is raised.
        """
        self.notify(len(self._waiters))
350
351
class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
    """A Semaphore implementation.

    A semaphore manages an internal counter which is decremented by each
    acquire() call and incremented by each release() call. The counter
    can never go below zero; when acquire() finds that it is zero, it blocks,
    waiting until some other task calls release().

    Semaphores also support the asynchronous context management protocol
    ('async with sem').

    The optional argument gives the initial value for the internal
    counter; it defaults to 1. If the value given is less than 0,
    ValueError is raised.

    Waiters are served in FIFO order: acquire() will not take a free slot
    while other, non-cancelled tasks are already queued.
    """

    def __init__(self, value=1):
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        # FIFO of futures for blocked acquire() calls; created lazily.
        self._waiters = None
        # Slots currently available.  _wake_up_next() decrements this on
        # behalf of the waiter it wakes, before that waiter resumes.
        self._value = value

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else f'unlocked, value:{self._value}'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately."""
        # Due to state, or FIFO rules (must allow others to run first).
        return self._value == 0 or (
            any(not w.cancelled() for w in (self._waiters or ())))

    async def acquire(self):
        """Acquire a semaphore.

        If the internal counter is larger than zero on entry and no other
        (non-cancelled) task is already waiting, decrement it by one and
        return True immediately.  Otherwise block, waiting until some
        other task has called release() and this task's turn has come,
        and then return True.
        """
        if not self.locked():
            # Fast path: a slot is free and nobody live is queued ahead of
            # us, so taking it now does not violate FIFO order.
            self._value -= 1
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            # Currently the only exception designed to be able to occur here.
            if fut.done() and not fut.cancelled():
                # Our Future was successfully set to True via _wake_up_next(),
                # but we are not about to successfully acquire(). Therefore we
                # must undo the bookkeeping already done and attempt to wake
                # up someone else.
                self._value += 1
            raise

        finally:
            # Runs on both the success and the cancellation path.
            # New waiters may have arrived but had to wait due to FIFO.
            # Wake up as many as are allowed.
            while self._value > 0:
                if not self._wake_up_next():
                    break  # There was no-one to wake up.
        return True

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        If another task is waiting, the first waiter that isn't done is
        woken and the counter is decremented again on its behalf.
        """
        self._value += 1
        self._wake_up_next()

    def _wake_up_next(self):
        """Wake up the first waiter that isn't done.

        Returns True if a waiter was woken (and one slot was taken from
        ``_value`` for it), False if there was no such waiter.
        """
        if not self._waiters:
            return False

        for fut in self._waiters:
            if not fut.done():
                self._value -= 1
                fut.set_result(True)
                # `fut` is now `done()` and not `cancelled()`.
                return True
        return False
449
450
class BoundedSemaphore(Semaphore):
    """Semaphore whose counter may never be released above its start value.

    release() raises ValueError rather than pushing the internal counter
    past the value given to the constructor.
    """

    def __init__(self, value=1):
        # Record the ceiling first; Semaphore.__init__ validates *value*.
        self._bound_value = value
        super().__init__(value)

    def release(self):
        """Release the semaphore unless it is already at its bound.

        Raises ValueError when the internal counter has already reached
        the initial value.
        """
        over_limit = self._value >= self._bound_value
        if over_limit:
            raise ValueError('BoundedSemaphore released too many times')
        super().release()
466
467
468
class _BarrierState(enum.Enum):
    """Phases a Barrier moves through; each value is its lowercase label."""

    # Tasks are arriving and the party count has not been reached yet.
    FILLING = 'filling'
    # The last party arrived; tasks inside the barrier are being released.
    DRAINING = 'draining'
    # reset() was called while tasks were still inside the barrier.
    RESETTING = 'resetting'
    # abort() was called; the barrier stays unusable.
    BROKEN = 'broken'
474
475
class Barrier(mixins._LoopBoundMixin):
    """Asyncio equivalent to threading.Barrier

    Implements a Barrier primitive.
    Useful for synchronizing a fixed number of tasks at known synchronization
    points. Tasks block on 'wait()' and are simultaneously awoken once they
    have all made their call.
    """

    def __init__(self, parties):
        """Create a barrier, initialised to 'parties' tasks.

        Raises ValueError if *parties* is less than 1.
        """
        if parties < 1:
            raise ValueError('parties must be > 0')

        self._cond = Condition() # notify all tasks when state changes

        self._parties = parties
        self._state = _BarrierState.FILLING
        # Number of tasks currently inside wait() (incremented after
        # _block(), decremented in wait()'s finally clause).
        self._count = 0       # count tasks in Barrier

    def __repr__(self):
        res = super().__repr__()
        extra = f'{self._state.value}'
        if not self.broken:
            extra += f', waiters:{self.n_waiting}/{self.parties}'
        return f'<{res[1:-1]} [{extra}]>'

    async def __aenter__(self):
        # Wait until all parties have arrived at the barrier, then return
        # this task's arrival index (the value returned by wait()).
        return await self.wait()

    async def __aexit__(self, *args):
        # Nothing to release on exit; exceptions from the block propagate.
        pass

    async def wait(self):
        """Wait for the barrier.

        When the specified number of tasks have started waiting, they are all
        simultaneously awoken.
        Returns a unique and individual index number from 0 to 'parties-1'.

        Raises BrokenBarrierError if the barrier is broken (abort()) or is
        reset while this task is waiting.
        """
        async with self._cond:
            await self._block() # Block while the barrier drains or resets.
            try:
                index = self._count
                self._count += 1
                if index + 1 == self._parties:
                    # We release the barrier
                    await self._release()
                else:
                    await self._wait()
                return index
            finally:
                self._count -= 1
                # Wake up any tasks waiting for barrier to drain.
                self._exit()

    async def _block(self):
        # Block until the barrier is ready for us,
        # or raise an exception if it is broken.
        #
        # It is draining or resetting, wait until done
        # unless a CancelledError occurs
        await self._cond.wait_for(
            lambda: self._state not in (
                _BarrierState.DRAINING, _BarrierState.RESETTING
            )
        )

        # see if the barrier is in a broken state
        if self._state is _BarrierState.BROKEN:
            raise exceptions.BrokenBarrierError("Barrier aborted")

    async def _release(self):
        # Release the tasks waiting in the barrier.
        # Called by the last arriving task with the condition lock held;
        # although declared async, the body does not await anything.

        # Enter draining state.
        # Next waiting tasks will be blocked until the end of draining.
        self._state = _BarrierState.DRAINING
        self._cond.notify_all()

    async def _wait(self):
        # Wait in the barrier until we are released. Raise an exception
        # if the barrier is reset or broken.

        # wait for end of filling
        # unless a CancelledError occurs
        await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)

        if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
            raise exceptions.BrokenBarrierError("Abort or reset of barrier")

    def _exit(self):
        # If we are the last tasks to exit the barrier, signal any tasks
        # waiting for the barrier to drain.
        # A BROKEN barrier is left broken; only DRAINING/RESETTING return
        # to FILLING once the barrier is empty.
        if self._count == 0:
            if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
                self._state = _BarrierState.FILLING
            self._cond.notify_all()

    async def reset(self):
        """Reset the barrier to the initial state.

        Any tasks currently waiting will get the BrokenBarrier exception
        raised.  If no task is inside the barrier, it goes straight back
        to the filling state (this also clears a broken state).
        """
        async with self._cond:
            if self._count > 0:
                if self._state is not _BarrierState.RESETTING:
                    # Reset the barrier, waking up tasks; the last task to
                    # leave switches the state back to FILLING in _exit().
                    self._state = _BarrierState.RESETTING
            else:
                self._state = _BarrierState.FILLING
            self._cond.notify_all()

    async def abort(self):
        """Place the barrier into a 'broken' state.

        Useful in case of error.  Any currently waiting tasks and tasks
        attempting to 'wait()' will have BrokenBarrierError raised.
        """
        async with self._cond:
            self._state = _BarrierState.BROKEN
            self._cond.notify_all()

    @property
    def parties(self):
        """Return the number of tasks required to trip the barrier."""
        return self._parties

    @property
    def n_waiting(self):
        """Return the number of tasks currently waiting at the barrier.

        Only counts tasks while the barrier is filling; returns 0 while
        it is draining, resetting or broken.
        """
        if self._state is _BarrierState.FILLING:
            return self._count
        return 0

    @property
    def broken(self):
        """Return True if the barrier is in a broken state."""
        return self._state is _BarrierState.BROKEN
618