# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Locking related utils."""

import threading


class GroupLock(object):
  """A lock to allow many members of a group to access a resource exclusively.

  This lock provides a way to allow access to a resource by multiple threads
  belonging to a logical group at the same time, while restricting access to
  threads from all other groups. You can think of this as an extension of a
  reader-writer lock, where you allow multiple writers at the same time. We
  made it generic to support multiple groups instead of just two - readers and
  writers.

  Simple usage example with two groups accessing the same resource:

  ```python
  lock = GroupLock(num_groups=2)

  # In a member of group 0:
  with lock.group(0):
    # do stuff, access the resource
    # ...

  # In a member of group 1:
  with lock.group(1):
    # do stuff, access the resource
    # ...
  ```

  Using as a context manager with `.group(group_id)` is the easiest way. You
  can also use the `acquire` and `release` method directly.

  Note that no fairness between groups is enforced: `acquire` only waits while
  a *different* group holds the lock, so members of the currently active group
  can keep entering while other groups are waiting.
  """

  __slots__ = ["_ready", "_num_groups", "_group_member_counts"]

  def __init__(self, num_groups=2):
    """Initialize a group lock.

    Args:
      num_groups: The number of groups that will be accessing the resource under
        consideration. Should be a positive number.

    Returns:
      A group lock that can then be used to synchronize code.

    Raises:
      ValueError: If num_groups is less than 1.
    """
    if num_groups < 1:
      raise ValueError(
          "Argument `num_groups` must be a positive integer. "
          f"Received: num_groups={num_groups}")
    # Condition guarding `_group_member_counts`; waiters are woken whenever a
    # group's member count drops back to zero.
    self._ready = threading.Condition(threading.Lock())
    self._num_groups = num_groups
    # `_group_member_counts[g]` is the number of outstanding acquisitions made
    # on behalf of group `g`.
    self._group_member_counts = [0] * self._num_groups

  def group(self, group_id):
    """Enter a context where the lock is with group `group_id`.

    Args:
      group_id: The group for which to acquire and release the lock.

    Returns:
      A context manager which will acquire the lock for `group_id`.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)
    return self._Context(self, group_id)

  def acquire(self, group_id):
    """Acquire the group lock for a specific group `group_id`.

    Blocks while any other group has outstanding acquisitions, then records one
    more member for `group_id`.

    Args:
      group_id: The group on whose behalf the lock is acquired.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)

    # `with` guarantees the internal lock is released even if `wait()` raises
    # (e.g. KeyboardInterrupt); otherwise every later call would deadlock.
    with self._ready:
      while self._another_group_active(group_id):
        self._ready.wait()
      self._group_member_counts[group_id] += 1

  def release(self, group_id):
    """Release the group lock for a specific group `group_id`.

    Wakes up all waiting threads once the last member of the group has
    released.

    Args:
      group_id: The group on whose behalf the lock is released.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
      RuntimeError: If `group_id` has no outstanding acquisition to release.
    """
    self._validate_group_id(group_id)

    with self._ready:
      # Reject unmatched releases *before* touching the counter: a negative
      # count would make a later holder of this group look inactive and let
      # other groups in concurrently.
      if self._group_member_counts[group_id] <= 0:
        raise RuntimeError(
            "Cannot release the group lock for a group that does not hold "
            f"it. Received: group_id={group_id}")
      self._group_member_counts[group_id] -= 1
      if self._group_member_counts[group_id] == 0:
        self._ready.notify_all()

  def _another_group_active(self, group_id):
    """Return True if any group other than `group_id` has active members.

    Must be called with `self._ready` held.
    """
    return any(
        c > 0 for g, c in enumerate(self._group_member_counts) if g != group_id)

  def _validate_group_id(self, group_id):
    """Raise ValueError unless `0 <= group_id < num_groups`."""
    if group_id < 0 or group_id >= self._num_groups:
      raise ValueError(
          "Argument `group_id` should verify `0 <= group_id < num_groups` "
          f"(with `num_groups={self._num_groups}`). "
          f"Received: group_id={group_id}")

  class _Context(object):
    """Context manager helper for `GroupLock`."""

    __slots__ = ["_lock", "_group_id"]

    def __init__(self, lock, group_id):
      """Bind the helper to `lock` and the group it will acquire for."""
      self._lock = lock
      self._group_id = group_id

    def __enter__(self):
      """Acquire the lock for the bound group; returns None."""
      self._lock.acquire(self._group_id)

    def __exit__(self, type_arg, value_arg, traceback_arg):
      """Release the lock for the bound group; exceptions are not suppressed."""
      del type_arg, value_arg, traceback_arg
      self._lock.release(self._group_id)