1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility to get tf.distribute.Strategy related contexts.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import threading 22 23from tensorflow.python.framework import ops 24from tensorflow.python.util.lazy_loader import LazyLoader 25from tensorflow.python.util.tf_export import tf_export 26 27 28# There is a circular dependency between this and the `distribute_lib` module. 29# So we load it lazily to work around this. 30distribute_lib = LazyLoader( 31 "distribute_lib", globals(), 32 "tensorflow.python.distribute.distribute_lib") 33 34# ------------------------------------------------------------------------------ 35# Internal API for setting the current thread mode as being either in a 36# replica or cross-replica context for a particular tf.distribute.Strategy. 


class _ThreadMode(object):
  """Bundles the three views of the current distribution state.

  Exactly one of `cross_replica_context` / `replica_context` is expected to
  be set by the subclasses below; `strategy` is always set.
  """

  def __init__(self, dist, cross, replica):
    self.strategy = dist
    self.cross_replica_context = cross
    self.replica_context = replica


class _CrossReplicaThreadMode(_ThreadMode):
  """Thread mode for a cross-replica context: no replica context is active."""

  def __init__(self, strategy):
    # In cross-replica mode the strategy itself serves as the
    # cross-replica context object.
    _ThreadMode.__init__(self, strategy, strategy, None)


class _InReplicaThreadMode(_ThreadMode):
  """Thread mode for a replica context: no cross-replica context is active."""

  def __init__(self, replica_ctx):
    _ThreadMode.__init__(self, replica_ctx.strategy, None, replica_ctx)


def _push_per_thread_mode(context):
  # Pushes `context` (a `_ThreadMode`) onto the current graph's per-thread
  # distribution-strategy stack.
  ops.get_default_graph()._distribution_strategy_stack.append(context)  # pylint: disable=protected-access


def _pop_per_thread_mode():
  # Pops the most recently pushed `_ThreadMode` off the current graph's
  # per-thread distribution-strategy stack.
  ops.get_default_graph()._distribution_strategy_stack.pop(-1)  # pylint: disable=protected-access


class _DefaultReplicaThreadMode(_ThreadMode):
  """Type of default value returned by `_get_per_thread_mode()`.

  Used when the thread-local stack is empty.
  """

  def __init__(self):
    _ThreadMode.__init__(self, _get_default_strategy(), None,
                         _get_default_replica_context())


def _get_per_thread_mode():
  """Returns the top `_ThreadMode`, or the default mode if the stack is empty.

  AttributeError covers graphs that have no `_distribution_strategy_stack`
  attribute at all; IndexError covers an empty stack.
  """
  try:
    return ops.get_default_graph()._distribution_strategy_stack[-1]  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    return _get_default_replica_mode()


# ------------------------------------------------------------------------------
# Public API for accessing the current thread mode


@tf_export("distribute.get_replica_context")
def get_replica_context():
  """Returns the current `tf.distribute.ReplicaContext` or `None`.

  Returns `None` if in a cross-replica context.

  Note that execution:

  1. starts in the default (single-replica) replica context (this function
     will return the default `ReplicaContext` object);
  2. switches to cross-replica context (in which case this will return
     `None`) when entering a `with tf.distribute.Strategy.scope():` block;
  3. switches to a (non-default) replica context inside
     `strategy.experimental_run_v2(fn, ...)`;
  4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
     inside `merge_fn` you are back in the cross-replica context (and again
     this function will return `None`).

  Most `tf.distribute.Strategy` methods may only be executed in
  a cross-replica context, in a replica context you should use the
  API of the `tf.distribute.ReplicaContext` object returned by this
  method instead.

  ```
  assert tf.distribute.get_replica_context() is not None  # default
  with strategy.scope():
    assert tf.distribute.get_replica_context() is None

    def f():
      replica_context = tf.distribute.get_replica_context()  # for strategy
      assert replica_context is not None
      tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
               " of ", replica_context.num_replicas_in_sync)

    strategy.experimental_run_v2(f)
  ```

  Returns:
    The current `tf.distribute.ReplicaContext` object when in a replica context
    scope, else `None`.

    Within a particular block, exactly one of these two things will be true:

    * `get_replica_context()` returns non-`None`, or
    * `tf.distribute.is_cross_replica_context()` returns True.
  """
  return _get_per_thread_mode().replica_context


def get_cross_replica_context():
  """Returns the current tf.distribute.Strategy if in a cross-replica context.

  DEPRECATED: Please use `in_cross_replica_context()` and
  `get_strategy()` instead.

  Returns:
    Returns the current `tf.distribute.Strategy` object in a cross-replica
    context, or `None`.

    Exactly one of `get_replica_context()` and `get_cross_replica_context()`
    will return `None` in a particular block.
  """
  return _get_per_thread_mode().cross_replica_context


@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
  """Returns `True` if in a cross-replica context.

  See `tf.distribute.get_replica_context` for details.

  ```
  assert not tf.distribute.in_cross_replica_context()
  with strategy.scope():
    assert tf.distribute.in_cross_replica_context()

    def f():
      assert not tf.distribute.in_cross_replica_context()

    strategy.experimental_run_v2(f)
  ```

  Returns:
    `True` if in a cross-replica context (`get_replica_context()` returns
    `None`), or `False` if in a replica context (`get_replica_context()` returns
    non-`None`).
  """
  return _get_per_thread_mode().cross_replica_context is not None


@tf_export("distribute.get_strategy")
def get_strategy():
  """Returns the current `tf.distribute.Strategy` object.

  Typically only used in a cross-replica context:

  ```
  if tf.distribute.in_cross_replica_context():
    strategy = tf.distribute.get_strategy()
    ...
  ```

  Returns:
    A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block,
    it returns `strategy`, otherwise it returns the default (single-replica)
    `tf.distribute.Strategy` object.
  """
  return _get_per_thread_mode().strategy


@tf_export("distribute.has_strategy")
def has_strategy():
  """Return if there is a current non-default `tf.distribute.Strategy`.

  ```
  assert not tf.distribute.has_strategy()
  with strategy.scope():
    assert tf.distribute.has_strategy()
  ```

  Returns:
    True if inside a `with strategy.scope():`.
  """
  # Identity comparison against the lazily-created default singleton: the
  # default strategy is created at most once (see _get_default_strategy), so
  # `is not` is a reliable "non-default" check.
  return get_strategy() is not _get_default_strategy()


def get_strategy_and_replica_context():
  """Returns both the current strategy and replica context as a 2-tuple.

  Reads the per-thread mode once, so the two values are consistent with each
  other even if another stack push/pop happens between calls.
  """
  per_thread_mode = _get_per_thread_mode()
  return (per_thread_mode.strategy, per_thread_mode.replica_context)


@tf_export("distribute.experimental_set_strategy")
def experimental_set_strategy(strategy):
  """Set a `tf.distribute.Strategy` as current without `with strategy.scope()`.

  ```
  tf.distribute.experimental_set_strategy(strategy1)
  f()
  tf.distribute.experimental_set_strategy(strategy2)
  g()
  tf.distribute.experimental_set_strategy(None)
  h()
  ```

  is equivalent to:

  ```
  with strategy1.scope():
    f()
  with strategy2.scope():
    g()
  h()
  ```

  In general, you should use the `with strategy.scope():` API, but this
  alternative may be convenient in notebooks where you would have to put
  each cell in a `with strategy.scope():` block.

  Note: This should only be called outside of any TensorFlow scope to
  avoid improper nesting.

  Args:
    strategy: A `tf.distribute.Strategy` object or None.

  Raises:
    RuntimeError: If called inside a `with strategy.scope():`.
  """
  # Exit any scope previously installed by this function before doing
  # anything else, so repeated calls replace (not nest) the global scope.
  old_scope = ops.get_default_graph()._global_distribute_strategy_scope  # pylint: disable=protected-access
  if old_scope is not None:
    old_scope.__exit__(None, None, None)
    ops.get_default_graph()._global_distribute_strategy_scope = None  # pylint: disable=protected-access
  # NOTE(review): if this raises, the previous global scope (if any) has
  # already been exited above — confirm that is the intended behavior for
  # the error path.
  if has_strategy():
    raise RuntimeError(
        "Must not be called inside a `tf.distribute.Strategy` scope.")
  if strategy is not None:
    new_scope = strategy.scope()
    # Enter the scope manually and stash the context manager on the graph so
    # a later call (or `experimental_set_strategy(None)`) can exit it.
    new_scope.__enter__()
    ops.get_default_graph()._global_distribute_strategy_scope = new_scope  # pylint: disable=protected-access


# ------------------------------------------------------------------------------
# Defaults that are used when no tf.distribute.Strategy is explicitly created.
# We create them lazily in a function so that we can workaround the circular
# dependency on distribute_lib. See lazy loader at the top of this file.

# Cache of lazily-created default objects; each entry is created at most once.
_defaults = {
    "strategy": None,
    "replica_context": None,
    "replica_mode": None
}
# Note: These need to be different locks since _get_default_replica_context
# calls _get_default_strategy inside its lock, and them using the same lock
# can lead to deadlock.
_default_strategy_lock = threading.Lock()
_default_replica_context_lock = threading.Lock()
_default_replica_mode_lock = threading.Lock()


def _get_default_strategy():
  """Returns the default `tf.distribute.Strategy`, creating it on first use.

  Uses double-checked locking so the singleton is created at most once even
  under concurrent first calls.
  """
  if _defaults["strategy"] is None:
    # Avoid race condition causing two defaults to be created.
    with _default_strategy_lock:
      if _defaults["strategy"] is None:
        # pylint: disable=protected-access
        # Make sure distribute_lib module is loaded by accessing some member.
        _ = distribute_lib._creating_default_strategy_singleton
        distribute_lib._creating_default_strategy_singleton = True
        try:
          _defaults["strategy"] = distribute_lib._DefaultDistributionStrategy()
        finally:
          # Reset the flag even if construction raises; otherwise a failed
          # first call would leave the singleton-creation flag stuck at True
          # for every subsequent strategy construction.
          distribute_lib._creating_default_strategy_singleton = False
        # pylint: enable=protected-access
  return _defaults["strategy"]


def _get_default_replica_context():
  """Returns the default `ReplicaContext`, creating it on first use."""
  if _defaults["replica_context"] is None:
    # Avoid race condition causing two defaults to be created.
    with _default_replica_context_lock:
      if _defaults["replica_context"] is None:
        # replica_id_in_sync_group=0: the default context models a single
        # replica, which is always replica 0.
        _defaults["replica_context"] = distribute_lib.ReplicaContext(
            _get_default_strategy(), replica_id_in_sync_group=0)
  return _defaults["replica_context"]


def _get_default_replica_mode():
  """Returns the default `_ThreadMode`, creating it on first use."""
  if _defaults["replica_mode"] is None:
    # Avoid race condition causing two defaults to be created.
    with _default_replica_mode_lock:
      if _defaults["replica_mode"] is None:
        _defaults["replica_mode"] = _DefaultReplicaThreadMode()
  return _defaults["replica_mode"]


# Aliases for compatibility with old names.
get_distribution_strategy = get_strategy
has_distribution_strategy = has_strategy