"""Synchronization primitives."""

__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
           'BoundedSemaphore', 'Barrier')

import collections
import enum

from . import exceptions
from . import mixins


class _ContextManagerMixin:
    """Add ``async with`` support on top of acquire() / release().

    Subclasses must provide a coroutine ``acquire()`` and a plain
    ``release()`` method.
    """

    async def __aenter__(self):
        await self.acquire()
        # We have no use for the "as ..." clause in the with
        # statement for locks.
        return None

    async def __aexit__(self, exc_type, exc, tb):
        self.release()


class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Primitive lock objects.

    A primitive lock is a synchronization primitive that is not owned
    by a particular task when locked.  A primitive lock is in one
    of two states, 'locked' or 'unlocked'.

    It is created in the unlocked state.  It has two basic methods,
    acquire() and release().  When the state is unlocked, acquire()
    changes the state to locked and returns immediately.  When the
    state is locked, acquire() blocks until a call to release() in
    another task changes it to unlocked, then the acquire() call
    resets it to locked and returns.  The release() method should only
    be called in the locked state; it changes the state to unlocked
    and returns immediately.  If an attempt is made to release an
    unlocked lock, a RuntimeError will be raised.

    When more than one task is blocked in acquire() waiting for
    the state to turn to unlocked, only one task proceeds when a
    release() call resets the state to unlocked; successive release()
    calls will unblock tasks in FIFO order.

    Locks also support the asynchronous context management protocol.
    'async with lock' statement should be used.

    Usage:

        lock = Lock()
        ...
        await lock.acquire()
        try:
            ...
        finally:
            lock.release()

    Context manager usage:

        lock = Lock()
        ...
        async with lock:
             ...

    Lock objects can be tested for locking state:

        if not lock.locked():
           await lock.acquire()
        else:
           # lock is acquired
           ...

    """

    def __init__(self):
        # The waiter deque is created lazily on the first contended
        # acquire(); None means "no task has ever had to wait".
        self._waiters = None
        self._locked = False

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self._locked else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def locked(self):
        """Return True if lock is acquired."""
        return self._locked

    async def acquire(self):
        """Acquire a lock.

        This method blocks until the lock is unlocked, then sets it to
        locked and returns True.

        Raises CancelledError if the waiting task is cancelled; in that
        case the lock is not acquired by this task.
        """
        # Implement fair scheduling, where a task always waits
        # its turn.  Jumping the queue if all are cancelled is an
        # optimization.
        if (not self._locked and (self._waiters is None or
                all(w.cancelled() for w in self._waiters))):
            self._locked = True
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            # Currently the only exception designed to be able to occur
            # here.

            # Ensure the lock invariant: If lock is not claimed (or about
            # to be claimed by us) and there is a Task in waiters,
            # ensure that the Task at the head will run.
            if not self._locked:
                self._wake_up_first()
            raise

        # assert self._locked is False
        self._locked = True
        return True

    def release(self):
        """Release a lock.

        When the lock is locked, reset it to unlocked, and return.
        If any other tasks are blocked waiting for the lock to become
        unlocked, allow exactly one of them to proceed.

        When invoked on an unlocked lock, a RuntimeError is raised.

        There is no return value.
        """
        if self._locked:
            self._locked = False
            self._wake_up_first()
        else:
            raise RuntimeError('Lock is not acquired.')

    def _wake_up_first(self):
        """Ensure that the first waiter will wake up."""
        if not self._waiters:
            # Covers both "never contended" (None) and an empty deque.
            return
        # The guard above guarantees the deque has a head, so plain
        # indexing cannot fail here.
        fut = self._waiters[0]

        # .done() means that the waiter is already set to wake up.
        if not fut.done():
            fut.set_result(True)


class Event(mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Event.

    Class implementing event objects.  An event manages a flag that can be set
    to true with the set() method and reset to false with the clear() method.
    The wait() method blocks until the flag is true.  The flag is initially
    false.
    """

    def __init__(self):
        self._waiters = collections.deque()
        self._value = False

    def __repr__(self):
        res = super().__repr__()
        extra = 'set' if self._value else 'unset'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def is_set(self):
        """Return True if and only if the internal flag is true."""
        return self._value

    def set(self):
        """Set the internal flag to true. All tasks waiting for it to
        become true are awakened. Tasks that call wait() once the flag is
        true will not block at all.
        """
        if not self._value:
            self._value = True

            for fut in self._waiters:
                if not fut.done():
                    fut.set_result(True)

    def clear(self):
        """Reset the internal flag to false. Subsequently, tasks calling
        wait() will block until set() is called to set the internal flag
        to true again."""
        self._value = False

    async def wait(self):
        """Block until the internal flag is true.

        If the internal flag is true on entry, return True
        immediately.  Otherwise, block until another task calls
        set() to set the flag to true, then return True.
        """
        if self._value:
            return True

        fut = self._get_loop().create_future()
        self._waiters.append(fut)
        try:
            await fut
            return True
        finally:
            # Always drop our future, whether woken or cancelled.
            self._waiters.remove(fut)


class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
    """Asynchronous equivalent to threading.Condition.

    This class implements condition variable objects. A condition variable
    allows one or more tasks to wait until they are notified by another
    task.

    If the lock argument is given and not None, it is used as the
    underlying lock.  Otherwise, a new Lock object is created and used
    as the underlying lock.
    """

    def __init__(self, lock=None):
        if lock is None:
            lock = Lock()

        self._lock = lock
        # Export the lock's locked(), acquire() and release() methods.
        self.locked = lock.locked
        self.acquire = lock.acquire
        self.release = lock.release

        self._waiters = collections.deque()

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else 'unlocked'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    async def wait(self):
        """Wait until notified.

        If the calling task has not acquired the lock when this
        method is called, a RuntimeError is raised.

        This method releases the underlying lock, and then blocks
        until it is awakened by a notify() or notify_all() call for
        the same condition variable in another task.  Once
        awakened, it re-acquires the lock and returns True.

        This method may return spuriously,
        which is why the caller should always
        re-check the state and be prepared to wait() again.
        """
        if not self.locked():
            raise RuntimeError('cannot wait on un-acquired lock')

        fut = self._get_loop().create_future()
        self.release()
        try:
            try:
                self._waiters.append(fut)
                try:
                    await fut
                    return True
                finally:
                    self._waiters.remove(fut)

            finally:
                # Must re-acquire lock even if wait is cancelled.
                # We only catch CancelledError here, since we don't want any
                # other (fatal) errors with the future to cause us to spin.
                err = None
                while True:
                    try:
                        await self.acquire()
                        break
                    except exceptions.CancelledError as e:
                        err = e

                if err is not None:
                    try:
                        raise err  # Re-raise most recent exception instance.
                    finally:
                        err = None  # Break reference cycles.
        except BaseException:
            # Any error raised out of here _may_ have occurred after this Task
            # believed to have been successfully notified.
            # Make sure to notify another Task instead.  This may result
            # in a "spurious wakeup", which is allowed as part of the
            # Condition Variable protocol.
            self._notify(1)
            raise

    async def wait_for(self, predicate):
        """Wait until a predicate becomes true.

        The predicate should be a callable whose result will be
        interpreted as a boolean value.  The method will repeatedly
        wait() until it evaluates to true.  The final predicate value is
        the return value.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n=1):
        """By default, wake up one task waiting on this condition, if any.
        If the calling task has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method wakes up n of the tasks waiting for the condition
        variable; if fewer than n are waiting, they are all awoken.

        Note: an awakened task does not actually return from its
        wait() call until it can reacquire the lock. Since notify() does
        not release the lock, its caller should.
        """
        if not self.locked():
            raise RuntimeError('cannot notify on un-acquired lock')
        self._notify(n)

    def _notify(self, n):
        """Wake up to n waiters whose futures are not done yet.

        Unlike notify(), this does not check that the lock is held; it is
        also called from wait()'s error path to hand a possibly-consumed
        notification on to another waiter.  Futures that are already done
        do not count towards n.
        """
        idx = 0
        for fut in self._waiters:
            if idx >= n:
                break

            if not fut.done():
                idx += 1
                fut.set_result(False)

    def notify_all(self):
        """Wake up all tasks waiting on this condition.  This method acts
        like notify(), but wakes up all waiting tasks instead of one.  If
        the calling task has not acquired the lock when this method is
        called, a RuntimeError is raised.
        """
        self.notify(len(self._waiters))


class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
    """A Semaphore implementation.

    A semaphore manages an internal counter which is decremented by each
    acquire() call and incremented by each release() call.  The counter
    can never go below zero; when acquire() finds that it is zero, it blocks,
    waiting until some other task calls release().

    Semaphores also support the asynchronous context management protocol.

    The optional argument gives the initial value for the internal
    counter; it defaults to 1.  If the value given is less than 0,
    ValueError is raised.
    """

    def __init__(self, value=1):
        if value < 0:
            raise ValueError("Semaphore initial value must be >= 0")
        # Created lazily on the first acquire() that has to wait.
        self._waiters = None
        self._value = value

    def __repr__(self):
        res = super().__repr__()
        extra = 'locked' if self.locked() else f'unlocked, value:{self._value}'
        if self._waiters:
            extra = f'{extra}, waiters:{len(self._waiters)}'
        return f'<{res[1:-1]} [{extra}]>'

    def locked(self):
        """Returns True if semaphore cannot be acquired immediately."""
        # Due to state, or FIFO rules (must allow others to run first).
        return self._value == 0 or (
            any(not w.cancelled() for w in (self._waiters or ())))

    async def acquire(self):
        """Acquire a semaphore.

        If the internal counter is larger than zero on entry,
        decrement it by one and return True immediately.  If it is
        zero on entry, block, waiting until some other task has
        called release() to make it larger than 0, and then return
        True.
        """
        if not self.locked():
            # Maintain FIFO, wait for others to start even if _value > 0.
            self._value -= 1
            return True

        if self._waiters is None:
            self._waiters = collections.deque()
        fut = self._get_loop().create_future()
        self._waiters.append(fut)

        try:
            try:
                await fut
            finally:
                self._waiters.remove(fut)
        except exceptions.CancelledError:
            # Currently the only exception designed to be able to occur
            # here.
            if fut.done() and not fut.cancelled():
                # Our Future was successfully set to True via _wake_up_next(),
                # but we are not about to successfully acquire(). Therefore we
                # must undo the bookkeeping already done and attempt to wake
                # up someone else.
                self._value += 1
            raise

        finally:
            # New waiters may have arrived but had to wait due to FIFO.
            # Wake up as many as are allowed.
            while self._value > 0:
                if not self._wake_up_next():
                    break  # There was no-one to wake up.
        return True

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        When it was zero on entry and another task is waiting for it to
        become larger than zero again, wake up that task.
        """
        self._value += 1
        self._wake_up_next()

    def _wake_up_next(self):
        """Wake up the first waiter that isn't done.

        The counter is decremented on the woken task's behalf here, so the
        woken acquire() does not decrement it again.  Returns True if a
        waiter was woken, False otherwise.
        """
        if not self._waiters:
            return False

        for fut in self._waiters:
            if not fut.done():
                self._value -= 1
                fut.set_result(True)
                # `fut` is now `done()` and not `cancelled()`.
                return True
        return False


class BoundedSemaphore(Semaphore):
    """A bounded semaphore implementation.

    This raises ValueError in release() if it would increase the value
    above the initial value.
    """

    def __init__(self, value=1):
        self._bound_value = value
        super().__init__(value)

    def release(self):
        if self._value >= self._bound_value:
            raise ValueError('BoundedSemaphore released too many times')
        super().release()


class _BarrierState(enum.Enum):
    FILLING = 'filling'
    DRAINING = 'draining'
    RESETTING = 'resetting'
    BROKEN = 'broken'


class Barrier(mixins._LoopBoundMixin):
    """Asyncio equivalent to threading.Barrier

    Implements a Barrier primitive.
    Useful for synchronizing a fixed number of tasks at known synchronization
    points.  Tasks block on 'wait()' and are simultaneously awoken once they
    have all made their call.
    """

    def __init__(self, parties):
        """Create a barrier, initialised to 'parties' tasks."""
        if parties < 1:
            raise ValueError('parties must be > 0')

        self._cond = Condition()  # notify all tasks when state changes

        self._parties = parties
        self._state = _BarrierState.FILLING
        self._count = 0       # count tasks in Barrier

    def __repr__(self):
        res = super().__repr__()
        extra = f'{self._state.value}'
        if not self.broken:
            extra += f', waiters:{self.n_waiting}/{self.parties}'
        return f'<{res[1:-1]} [{extra}]>'

    async def __aenter__(self):
        # Wait for the barrier to reach the parties number; once it
        # starts draining, return the index of this waiting task.
        return await self.wait()

    async def __aexit__(self, *args):
        pass

    async def wait(self):
        """Wait for the barrier.

        When the specified number of tasks have started waiting, they are all
        simultaneously awoken.
        Returns a unique and individual index number from 0 to 'parties-1'.

        Raises BrokenBarrierError if the barrier is broken or reset while
        this task is waiting, or is already broken on entry.
        """
        async with self._cond:
            await self._block()  # Block while the barrier drains or resets.
            try:
                index = self._count
                self._count += 1
                if index + 1 == self._parties:
                    # We release the barrier
                    await self._release()
                else:
                    await self._wait()
                return index
            finally:
                self._count -= 1
                # Wake up any tasks waiting for barrier to drain.
                self._exit()

    async def _block(self):
        # Block until the barrier is ready for us,
        # or raise an exception if it is broken.
        #
        # It is draining or resetting, wait until done
        # unless a CancelledError occurs
        await self._cond.wait_for(
            lambda: self._state not in (
                _BarrierState.DRAINING, _BarrierState.RESETTING
            )
        )

        # see if the barrier is in a broken state
        if self._state is _BarrierState.BROKEN:
            raise exceptions.BrokenBarrierError("Barrier aborted")

    async def _release(self):
        # Release the tasks waiting in the barrier.

        # Enter draining state.
        # Next waiting tasks will be blocked until the end of draining.
        self._state = _BarrierState.DRAINING
        self._cond.notify_all()

    async def _wait(self):
        # Wait in the barrier until we are released. Raise an exception
        # if the barrier is reset or broken.

        # wait for end of filling
        # unless a CancelledError occurs
        await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)

        if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
            raise exceptions.BrokenBarrierError("Abort or reset of barrier")

    def _exit(self):
        # If we are the last tasks to exit the barrier, signal any tasks
        # waiting for the barrier to drain.  A BROKEN state is left as is.
        if self._count == 0:
            if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
                self._state = _BarrierState.FILLING
            self._cond.notify_all()

    async def reset(self):
        """Reset the barrier to the initial state.

        Any tasks currently waiting will get the BrokenBarrier exception
        raised.  If no task is inside the barrier, it goes straight back
        to the filling state (this also clears a broken state).
        """
        async with self._cond:
            if self._count > 0:
                if self._state is not _BarrierState.RESETTING:
                    # Reset the barrier, waking up tasks
                    self._state = _BarrierState.RESETTING
            else:
                self._state = _BarrierState.FILLING
            self._cond.notify_all()

    async def abort(self):
        """Place the barrier into a 'broken' state.

        Useful in case of error.  Any currently waiting tasks and tasks
        attempting to 'wait()' will have BrokenBarrierError raised.
        """
        async with self._cond:
            self._state = _BarrierState.BROKEN
            self._cond.notify_all()

    @property
    def parties(self):
        """Return the number of tasks required to trip the barrier."""
        return self._parties

    @property
    def n_waiting(self):
        """Return the number of tasks currently waiting at the barrier."""
        if self._state is _BarrierState.FILLING:
            return self._count
        return 0

    @property
    def broken(self):
        """Return True if the barrier is in a broken state."""
        return self._state is _BarrierState.BROKEN