# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utility to get tf.distribute.Strategy related contexts."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import threading
22
23from tensorflow.python.framework import ops
24from tensorflow.python.util.lazy_loader import LazyLoader
25from tensorflow.python.util.tf_export import tf_export
26
27
# There is a circular dependency between this and the `distribute_lib` module.
# So we load it lazily to work around this. The module is only actually
# imported on first attribute access.
distribute_lib = LazyLoader(
    "distribute_lib", globals(),
    "tensorflow.python.distribute.distribute_lib")
33
# ------------------------------------------------------------------------------
# Internal API for setting the current thread mode as being either in a
# replica or cross-replica context for a particular tf.distribute.Strategy.

38
39class _ThreadMode(object):
40
41  def __init__(self, dist, cross, replica):
42    self.strategy = dist
43    self.cross_replica_context = cross
44    self.replica_context = replica
45
46
class _CrossReplicaThreadMode(_ThreadMode):
  """Thread mode for code running in a cross-replica context."""

  def __init__(self, strategy):
    # In cross-replica mode the strategy object doubles as the
    # cross-replica context; there is no replica context.
    super(_CrossReplicaThreadMode, self).__init__(strategy, strategy, None)
51
52
class _InReplicaThreadMode(_ThreadMode):
  """Thread mode for code running inside a replica context."""

  def __init__(self, replica_ctx):
    # The strategy is recovered from the replica context; there is no
    # cross-replica context while in replica mode.
    super(_InReplicaThreadMode, self).__init__(
        replica_ctx.strategy, None, replica_ctx)
57
58
def _push_per_thread_mode(context):
  """Pushes `context` onto the default graph's per-thread mode stack."""
  stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
  stack.append(context)
61
62
def _pop_per_thread_mode():
  """Removes the most recently pushed per-thread mode."""
  stack = ops.get_default_graph()._distribution_strategy_stack  # pylint: disable=protected-access
  stack.pop(-1)
65
66
class _DefaultReplicaThreadMode(_ThreadMode):
  """Type of default value returned by `_get_per_thread_mode()`.

  Used when the thread-local stack is empty: wraps the lazily-created
  default strategy and default replica context.
  """

  def __init__(self):
    super(_DefaultReplicaThreadMode, self).__init__(
        _get_default_strategy(), None, _get_default_replica_context())
76
77
def _get_per_thread_mode():
  """Returns the innermost pushed mode, or the default replica mode."""
  # The graph may not have the stack attribute at all, or the stack may
  # be empty; both cases fall back to the default mode.
  stack = getattr(
      ops.get_default_graph(), "_distribution_strategy_stack", None)
  if stack:
    return stack[-1]
  return _get_default_replica_mode()
83
84
# ------------------------------------------------------------------------------
# Public API for accessing the current thread mode

88
@tf_export("distribute.get_replica_context")
def get_replica_context():
  """Returns the current `tf.distribute.ReplicaContext` or `None`.

  `None` is returned when the caller is in a cross-replica context.

  Note that execution:

  1. begins in the default (single-replica) replica context, where this
     function returns the default `ReplicaContext` object;
  2. moves to a cross-replica context (where this returns `None`) on
     entering a `with tf.distribute.Strategy.scope():` block;
  3. moves to a (non-default) replica context inside
     `strategy.experimental_run_v2(fn, ...)`;
  4. returns to the cross-replica context inside `merge_fn` if `fn` calls
     `get_replica_context().merge_call(merge_fn, ...)`, so this function
     again returns `None` there.

  Most `tf.distribute.Strategy` methods may only be executed in a
  cross-replica context; in a replica context you should instead use the
  API of the `tf.distribute.ReplicaContext` object returned by this
  method.

  ```
  assert tf.distribute.get_replica_context() is not None  # default
  with strategy.scope():
    assert tf.distribute.get_replica_context() is None

    def f():
      replica_context = tf.distribute.get_replica_context()  # for strategy
      assert replica_context is not None
      tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
               " of ", replica_context.num_replicas_in_sync)

    strategy.experimental_run_v2(f)
  ```

  Returns:
    The current `tf.distribute.ReplicaContext` object when in a replica
    context scope, else `None`.

    Within a particular block, exactly one of these two things will be true:

    * `get_replica_context()` returns non-`None`, or
    * `tf.distribute.is_cross_replica_context()` returns True.
  """
  mode = _get_per_thread_mode()
  return mode.replica_context
136
137
def get_cross_replica_context():
  """Returns the current tf.distribute.Strategy if in a cross-replica context.

  DEPRECATED: Please use `in_cross_replica_context()` and
  `get_strategy()` instead.

  Returns:
    The current `tf.distribute.Strategy` object when in a cross-replica
    context, `None` otherwise.

    Exactly one of `get_replica_context()` and `get_cross_replica_context()`
    will return `None` in a particular block.
  """
  mode = _get_per_thread_mode()
  return mode.cross_replica_context
152
153
@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
  """Returns `True` if in a cross-replica context.

  See `tf.distribute.get_replica_context` for details.

  ```
  assert not tf.distribute.in_cross_replica_context()
  with strategy.scope():
    assert tf.distribute.in_cross_replica_context()

    def f():
      assert not tf.distribute.in_cross_replica_context()

    strategy.experimental_run_v2(f)
  ```

  Returns:
    `True` when in a cross-replica context (i.e. `get_replica_context()`
    returns `None`); `False` when in a replica context
    (`get_replica_context()` returns non-`None`).
  """
  cross = _get_per_thread_mode().cross_replica_context
  return cross is not None
177
178
@tf_export("distribute.get_strategy")
def get_strategy():
  """Returns the current `tf.distribute.Strategy` object.

  Typically only used in a cross-replica context:

  ```
  if tf.distribute.in_cross_replica_context():
    strategy = tf.distribute.get_strategy()
    ...
  ```

  Returns:
    A `tf.distribute.Strategy` object. Inside a `with strategy.scope()`
    block this is `strategy`; otherwise it is the default (single-replica)
    `tf.distribute.Strategy` object.
  """
  mode = _get_per_thread_mode()
  return mode.strategy
197
198
@tf_export("distribute.has_strategy")
def has_strategy():
  """Return if there is a current non-default `tf.distribute.Strategy`.

  ```
  assert not tf.distribute.has_strategy()
  with strategy.scope():
    assert tf.distribute.has_strategy()
  ```

  Returns:
    True if inside a `with strategy.scope():`.
  """
  # Identity comparison is intentional: the default strategy is a
  # lazily-created singleton, so any other object is user-supplied.
  current = get_strategy()
  return current is not _get_default_strategy()
213
214
def get_strategy_and_replica_context():
  """Returns a `(strategy, replica_context)` pair for the current thread."""
  mode = _get_per_thread_mode()
  return (mode.strategy, mode.replica_context)
218
219
@tf_export("distribute.experimental_set_strategy")
def experimental_set_strategy(strategy):
  """Set a `tf.distribute.Strategy` as current without `with strategy.scope()`.

  ```
  tf.distribute.experimental_set_strategy(strategy1)
  f()
  tf.distribute.experimental_set_strategy(strategy2)
  g()
  tf.distribute.experimental_set_strategy(None)
  h()
  ```

  is equivalent to:

  ```
  with strategy1.scope():
    f()
  with strategy2.scope():
    g()
  h()
  ```

  In general, you should use the `with strategy.scope():` API, but this
  alternative may be convenient in notebooks where you would have to put
  each cell in a `with strategy.scope():` block.

  Note: This should only be called outside of any TensorFlow scope to
  avoid improper nesting.

  Args:
    strategy: A `tf.distribute.Strategy` object or None.

  Raises:
    RuntimeError: If called inside a `with strategy.scope():`.
  """
  # First tear down any scope a previous call to this function entered.
  # NOTE: `ops.get_default_graph()` is deliberately re-evaluated at each
  # step rather than cached, matching the scope enter/exit side effects.
  previous_scope = ops.get_default_graph()._global_distribute_strategy_scope  # pylint: disable=protected-access
  if previous_scope is not None:
    previous_scope.__exit__(None, None, None)
    ops.get_default_graph()._global_distribute_strategy_scope = None  # pylint: disable=protected-access
  # Any strategy still active here was entered via a `with` block, which
  # this function must not be nested inside.
  if has_strategy():
    raise RuntimeError(
        "Must not be called inside a `tf.distribute.Strategy` scope.")
  if strategy is not None:
    # Enter the new scope and remember it so a later call can exit it.
    entered_scope = strategy.scope()
    entered_scope.__enter__()
    ops.get_default_graph()._global_distribute_strategy_scope = entered_scope  # pylint: disable=protected-access
267
268
# ------------------------------------------------------------------------------
# Defaults that are used when no tf.distribute.Strategy is explicitly created.
# We create them lazily in a function so that we can workaround the circular
# dependency on distribute_lib. See lazy loader at the top of this file.
273
# Lazily-created singletons; each entry is filled in at most once by the
# corresponding `_get_default_*` function below.
_defaults = {
    "strategy": None,
    "replica_context": None,
    "replica_mode": None
}
# Note: These need to be different locks since _get_default_replica_context
# calls _get_default_strategy inside its lock, and them using the same lock
# can lead to deadlock.
_default_strategy_lock = threading.Lock()
_default_replica_context_lock = threading.Lock()
_default_replica_mode_lock = threading.Lock()
285
286
def _get_default_strategy():
  """Returns the default strategy singleton, creating it on first use."""
  if _defaults["strategy"] is None:
    # Double-checked locking: avoid a race creating two defaults.
    with _default_strategy_lock:
      if _defaults["strategy"] is None:
        # pylint: disable=protected-access
        # Make sure distribute_lib module is loaded by accessing some member.
        _ = distribute_lib._creating_default_strategy_singleton
        distribute_lib._creating_default_strategy_singleton = True
        try:
          _defaults["strategy"] = distribute_lib._DefaultDistributionStrategy()
        finally:
          # Reset the flag even if construction raises; otherwise a later
          # call would observe a stale True value.
          distribute_lib._creating_default_strategy_singleton = False
        # pylint: enable=protected-access
  return _defaults["strategy"]
300
301
def _get_default_replica_context():
  """Returns the default `ReplicaContext`, creating it on first use."""
  if _defaults["replica_context"] is not None:
    return _defaults["replica_context"]
  # Double-checked locking: only one thread creates the singleton.
  with _default_replica_context_lock:
    if _defaults["replica_context"] is None:
      _defaults["replica_context"] = distribute_lib.ReplicaContext(
          _get_default_strategy(), replica_id_in_sync_group=0)
  return _defaults["replica_context"]
310
311
def _get_default_replica_mode():
  """Returns the default thread mode, creating it on first use."""
  if _defaults["replica_mode"] is not None:
    return _defaults["replica_mode"]
  # Double-checked locking: racing threads create at most one instance.
  with _default_replica_mode_lock:
    if _defaults["replica_mode"] is None:
      _defaults["replica_mode"] = _DefaultReplicaThreadMode()
  return _defaults["replica_mode"]
319
320
# Aliases for compatibility with old names. Deprecated: kept only for
# callers that predate the `*_strategy` naming.
get_distribution_strategy = get_strategy
has_distribution_strategy = has_strategy
324