• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Locking related utils."""
16
17import threading
18
19
class GroupLock(object):
  """A lock to allow many members of a group to access a resource exclusively.

  This lock provides a way to allow access to a resource by multiple threads
  belonging to a logical group at the same time, while restricting access to
  threads from all other groups. You can think of this as an extension of a
  reader-writer lock, where you allow multiple writers at the same time. We
  made it generic to support multiple groups instead of just two - readers and
  writers.

  Simple usage example with two groups accessing the same resource:

  ```python
  lock = GroupLock(num_groups=2)

  # In a member of group 0:
  with lock.group(0):
    # do stuff, access the resource
    # ...

  # In a member of group 1:
  with lock.group(1):
    # do stuff, access the resource
    # ...
  ```

  Using as a context manager with `.group(group_id)` is the easiest way. You
  can also use the `acquire` and `release` methods directly; every `acquire`
  must be paired with exactly one `release` for the same group.

  Note that the lock is not fair: while a group is active, new members of that
  same group are admitted immediately, so a steady stream of members from one
  group can keep threads of other groups waiting indefinitely.
  """

  __slots__ = ["_ready", "_num_groups", "_group_member_counts"]

  def __init__(self, num_groups=2):
    """Initialize a group lock.

    Args:
      num_groups: The number of groups that will be accessing the resource under
        consideration. Should be a positive number.

    Raises:
      ValueError: If num_groups is less than 1.
    """
    if num_groups < 1:
      raise ValueError(
          "Argument `num_groups` must be a positive integer. "
          f"Received: num_groups={num_groups}")
    # Guards `_group_member_counts`; waiting acquirers are woken whenever some
    # group's member count drops back to zero.
    self._ready = threading.Condition(threading.Lock())
    self._num_groups = num_groups
    # Number of threads of each group currently holding the lock.
    self._group_member_counts = [0] * self._num_groups

  def group(self, group_id):
    """Enter a context where the lock is with group `group_id`.

    Args:
      group_id: The group for which to acquire and release the lock.

    Returns:
      A context manager which will acquire the lock for `group_id`.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)
    return self._Context(self, group_id)

  def acquire(self, group_id):
    """Acquire the group lock for a specific group `group_id`.

    Blocks until no thread of any other group holds the lock, then registers
    the caller as an active member of `group_id`.

    Args:
      group_id: The group on whose behalf to acquire the lock.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)

    # `with` guarantees the condition's underlying lock is released even if the
    # wait or the update raises (e.g. KeyboardInterrupt); a bare
    # acquire()/release() pair would leave it held and deadlock later callers.
    with self._ready:
      self._ready.wait_for(lambda: not self._another_group_active(group_id))
      self._group_member_counts[group_id] += 1

  def release(self, group_id):
    """Release the group lock for a specific group `group_id`.

    Args:
      group_id: The group on whose behalf the lock was acquired.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
      RuntimeError: If no member of `group_id` currently holds the lock.
    """
    self._validate_group_id(group_id)

    with self._ready:
      if self._group_member_counts[group_id] <= 0:
        # Releasing more often than acquiring would drive the count negative,
        # so later members of this group would go uncounted and other groups
        # could be admitted alongside them.
        raise RuntimeError(
            f"Cannot release group {group_id}: it does not hold the lock.")
      self._group_member_counts[group_id] -= 1
      if self._group_member_counts[group_id] == 0:
        # Last member of this group left: threads of other groups blocked in
        # `acquire` may now be able to proceed.
        self._ready.notify_all()

  def _another_group_active(self, group_id):
    """Return True if any group other than `group_id` has active members.

    Reads shared state, so it must be called with `self._ready` held.
    """
    return any(
        c > 0 for g, c in enumerate(self._group_member_counts) if g != group_id)

  def _validate_group_id(self, group_id):
    """Raise ValueError unless `0 <= group_id < num_groups`."""
    if group_id < 0 or group_id >= self._num_groups:
      raise ValueError(
          "Argument `group_id` should verify `0 <= group_id < num_groups` "
          f"(with `num_groups={self._num_groups}`). "
          f"Received: group_id={group_id}")

  class _Context(object):
    """Context manager helper for `GroupLock`."""

    __slots__ = ["_lock", "_group_id"]

    def __init__(self, lock, group_id):
      self._lock = lock
      self._group_id = group_id

    def __enter__(self):
      self._lock.acquire(self._group_id)

    def __exit__(self, type_arg, value_arg, traceback_arg):
      # Always release; returning None lets any exception propagate.
      del type_arg, value_arg, traceback_arg
      self._lock.release(self._group_id)
131