# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to get tf.distribute.Strategy related contexts."""

import contextlib
import threading

from tensorflow.python import tf2
from tensorflow.python.framework import ops
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export


# There is a circular dependency between this and the `distribute_lib` module.
# So we load it lazily to work around this.
distribute_lib = LazyLoader(
    "distribute_lib", globals(),
    "tensorflow.python.distribute.distribute_lib")

# ------------------------------------------------------------------------------
# Internal API for setting the current thread mode as being either in a
# replica or cross-replica context for a particular tf.distribute.Strategy.


class _ThreadMode(object):

  def __init__(self, dist, cross, replica):
    self.strategy = dist
    self.cross_replica_context = cross
    self.replica_context = replica


class _CrossReplicaThreadMode(_ThreadMode):

  def __init__(self, strategy):
    _ThreadMode.__init__(self, strategy, strategy, None)


class _InReplicaThreadMode(_ThreadMode):

  def __init__(self, replica_ctx):
    _ThreadMode.__init__(self, replica_ctx.strategy, None, replica_ctx)


def _push_per_thread_mode(context):
  ops.get_default_graph()._distribution_strategy_stack.append(context)  # pylint: disable=protected-access


def _pop_per_thread_mode():
  ops.get_default_graph()._distribution_strategy_stack.pop(-1)  # pylint: disable=protected-access


class _DefaultReplicaThreadMode(_ThreadMode):
  """Type of default value returned by `_get_per_thread_mode()`.

  Used when the thread-local stack is empty.
  """

  def __init__(self):
    _ThreadMode.__init__(self, _get_default_strategy(), None,
                         _get_default_replica_context())


def _get_per_thread_mode():
  try:
    return ops.get_default_graph()._distribution_strategy_stack[-1]  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    return _get_default_replica_mode()
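

# A minimal sketch (illustrative only, not part of this module's API) of how
# the per-thread mode stack above behaves: pushing a `_CrossReplicaThreadMode`
# makes the topmost entry report a cross-replica context until it is popped.
# `strategy` is assumed to be any `tf.distribute.Strategy` instance.
def _example_thread_mode_stack(strategy):
  _push_per_thread_mode(_CrossReplicaThreadMode(strategy))
  try:
    mode = _get_per_thread_mode()
    assert mode.cross_replica_context is strategy
    assert mode.replica_context is None
  finally:
    _pop_per_thread_mode()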


_variable_sync_on_read_context = threading.local()


@tf_export("__internal__.distribute.variable_sync_on_read_context", v1=[])
@contextlib.contextmanager
def variable_sync_on_read_context():
  """A context that forces SyncOnReadVariable to aggregate upon reading.

  This context is useful if one wants to read the aggregated value out of a
  SyncOnReadVariable in replica context. By default the aggregation is turned
  off per the definition of SyncOnReadVariable.

  When reading a SyncOnReadVariable in cross-replica context, aggregation is
  always turned on, so there is no need for such a context.

  By reading a SyncOnReadVariable, we mean:
  1. Converting the variable to a tensor using `convert_to_tensor`.
  2. Calling `variable.value()` or `variable.read_value()`.

  Example usage:

  ```
  strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
  with strategy.scope():
    v = tf.Variable(1.0, synchronization=tf.VariableSynchronization.ON_READ,
                    aggregation=tf.VariableAggregation.SUM)

  def replica_fn():
    return v + 10.0

  non_aggregated = strategy.run(replica_fn)
  print(non_aggregated)  # PerReplica: {0: 11.0, 1: 11.0}

  def replica_fn():
    with variable_sync_on_read_context():
      return v + 10.0

  aggregated = strategy.run(replica_fn)
  print(aggregated)  # PerReplica: {0: 12.0, 1: 12.0}
  ```

  Yields:
    Context manager for aggregating SyncOnReadVariable upon reading.
  """
  try:
    _variable_sync_on_read_context.entered = True
    yield
  finally:
    _variable_sync_on_read_context.entered = False


def in_variable_sync_on_read_context():
  """Returns True if the current thread is in a `variable_sync_on_read_context`."""
  try:
    return _variable_sync_on_read_context.entered
  except AttributeError:
    return False
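

# Illustrative sketch (not part of this module's API): the flag set by
# `variable_sync_on_read_context` is thread-local and reset on exit, so it
# only affects reads performed by the current thread inside the context.
def _example_sync_on_read_flag():
  assert not in_variable_sync_on_read_context()
  with variable_sync_on_read_context():
    assert in_variable_sync_on_read_context()
  assert not in_variable_sync_on_read_context()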


# ------------------------------------------------------------------------------
# Public API for accessing the current thread mode


@tf_export("distribute.get_replica_context")
def get_replica_context():
  """Returns the current `tf.distribute.ReplicaContext` or `None`.

  Returns `None` if in a cross-replica context.

  Note that execution:

  1. starts in the default (single-replica) replica context (this function
     will return the default `ReplicaContext` object);
  2. switches to cross-replica context (in which case this will return
     `None`) when entering a `with tf.distribute.Strategy.scope():` block;
  3. switches to a (non-default) replica context inside
     `strategy.run(fn, ...)`;
  4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
     inside `merge_fn` you are back in the cross-replica context (and again
     this function will return `None`).

  Most `tf.distribute.Strategy` methods may only be executed in
  a cross-replica context; in a replica context you should use the
  API of the `tf.distribute.ReplicaContext` object returned by this
  method instead.

  ```
  assert tf.distribute.get_replica_context() is not None  # default
  with strategy.scope():
    assert tf.distribute.get_replica_context() is None

    def f():
      replica_context = tf.distribute.get_replica_context()  # for strategy
      assert replica_context is not None
      tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
               " of ", replica_context.num_replicas_in_sync)

    strategy.run(f)
  ```

  Returns:
    The current `tf.distribute.ReplicaContext` object when in a replica context
    scope, else `None`.

    Within a particular block, exactly one of these two things will be true:

    * `get_replica_context()` returns non-`None`, or
    * `tf.distribute.in_cross_replica_context()` returns True.
  """
  return _get_per_thread_mode().replica_context


def get_cross_replica_context():
  """Returns the current tf.distribute.Strategy if in a cross-replica context.

  DEPRECATED: Please use `in_cross_replica_context()` and
  `get_strategy()` instead.

  Returns:
    The current `tf.distribute.Strategy` object in a cross-replica context,
    or `None`.

    Exactly one of `get_replica_context()` and `get_cross_replica_context()`
    will return `None` in a particular block.
  """
  return _get_per_thread_mode().cross_replica_context


@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
  """Returns `True` if in a cross-replica context.

  See `tf.distribute.get_replica_context` for details.

  ```
  assert not tf.distribute.in_cross_replica_context()
  with strategy.scope():
    assert tf.distribute.in_cross_replica_context()

    def f():
      assert not tf.distribute.in_cross_replica_context()

    strategy.run(f)
  ```

  Returns:
    `True` if in a cross-replica context (`get_replica_context()` returns
    `None`), or `False` if in a replica context (`get_replica_context()`
    returns non-`None`).
  """
  return _get_per_thread_mode().cross_replica_context is not None


@tf_export("distribute.get_strategy")
def get_strategy():
  """Returns the current `tf.distribute.Strategy` object.

  Typically only used in a cross-replica context:

  ```
  if tf.distribute.in_cross_replica_context():
    strategy = tf.distribute.get_strategy()
    ...
  ```

  Returns:
    A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block,
    it returns `strategy`, otherwise it returns the default (single-replica)
    `tf.distribute.Strategy` object.
  """
  return _get_per_thread_mode().strategy


@tf_export("distribute.has_strategy")
def has_strategy():
  """Returns whether there is a current non-default `tf.distribute.Strategy`.

  ```
  assert not tf.distribute.has_strategy()
  with strategy.scope():
    assert tf.distribute.has_strategy()
  ```

  Returns:
    True if inside a `with strategy.scope():`.
  """
  return get_strategy() is not _get_default_strategy()


def get_strategy_and_replica_context():
  """Returns the current strategy and replica context as a tuple."""
  per_thread_mode = _get_per_thread_mode()
  return (per_thread_mode.strategy, per_thread_mode.replica_context)
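

# A minimal sketch (illustrative only, not part of this module's API) of the
# dispatch pattern these accessors enable: code that must work in either
# context reduces directly when cross-replica, and otherwise uses
# `merge_call` to switch back to the cross-replica context. The "SUM"
# reduction is an arbitrary example choice.
def _example_context_dispatch(value):
  if in_cross_replica_context():
    return get_strategy().reduce("SUM", value, axis=None)
  # In a replica context, `merge_call` runs its argument in the
  # cross-replica context with the strategy as the first argument.
  return get_replica_context().merge_call(
      lambda strategy, v: strategy.reduce("SUM", v, axis=None), args=(value,))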


@tf_export("distribute.experimental_set_strategy")
def experimental_set_strategy(strategy):
  """Set a `tf.distribute.Strategy` as current without `with strategy.scope()`.

  ```
  tf.distribute.experimental_set_strategy(strategy1)
  f()
  tf.distribute.experimental_set_strategy(strategy2)
  g()
  tf.distribute.experimental_set_strategy(None)
  h()
  ```

  is equivalent to:

  ```
  with strategy1.scope():
    f()
  with strategy2.scope():
    g()
  h()
  ```

  In general, you should use the `with strategy.scope():` API, but this
  alternative may be convenient in notebooks where you would have to put
  each cell in a `with strategy.scope():` block.

  Note: This should only be called outside of any TensorFlow scope to
  avoid improper nesting.

  Args:
    strategy: A `tf.distribute.Strategy` object or None.

  Raises:
    RuntimeError: If called inside a `with strategy.scope():`.
  """
  old_scope = ops.get_default_graph()._global_distribute_strategy_scope  # pylint: disable=protected-access
  if old_scope is not None:
    old_scope.__exit__(None, None, None)
    ops.get_default_graph()._global_distribute_strategy_scope = None  # pylint: disable=protected-access
  if has_strategy():
    raise RuntimeError(
        "Must not be called inside a `tf.distribute.Strategy` scope.")
  if strategy is not None:
    new_scope = strategy.scope()
    new_scope.__enter__()
    ops.get_default_graph()._global_distribute_strategy_scope = new_scope  # pylint: disable=protected-access
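

# Illustrative sketch (not part of this module's API), assuming it runs
# outside any explicit `with strategy.scope():` block: setting a strategy
# enters its scope implicitly, and passing `None` exits it again.
def _example_set_strategy(strategy):
  experimental_set_strategy(strategy)  # implicitly enters `strategy.scope()`
  assert has_strategy() and get_strategy() is strategy
  experimental_set_strategy(None)  # exits the implicit scope
  assert not has_strategy()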


# ------------------------------------------------------------------------------
# Internal helpers.


@contextlib.contextmanager
def enter_or_assert_strategy(strategy):
  if has_strategy():
    _assert_strategy(strategy)
    yield
  else:
    with strategy.scope():
      yield


# ------------------------------------------------------------------------------
# Defaults that are used when no tf.distribute.Strategy is explicitly created.
# We create them lazily in a function so that we can work around the circular
# dependency on distribute_lib. See lazy loader at the top of this file.

_defaults = {
    "strategy": None,
    "replica_context": None,
    "replica_mode": None
}
# Note: These need to be different locks since _get_default_replica_context
# calls _get_default_strategy inside its lock, and using a single shared lock
# could lead to deadlock.
_default_strategy_lock = threading.Lock()
_default_replica_context_lock = threading.Lock()
_default_replica_mode_lock = threading.Lock()


def _assert_strategy(strategy):
  if not has_strategy():
    raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
                       (strategy,))
  current_strategy = get_strategy()
  if current_strategy is not strategy:
    raise RuntimeError(
        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
        (current_strategy, strategy))


def _get_default_strategy():
  if _defaults["strategy"] is None:
    # Avoid a race condition causing two defaults to be created.
    with _default_strategy_lock:
      if _defaults["strategy"] is None:
        # pylint: disable=protected-access
        # Make sure distribute_lib module is loaded by accessing some member.
        _ = distribute_lib._creating_default_strategy_singleton
        distribute_lib._creating_default_strategy_singleton = True
        if tf2.enabled():
          _defaults["strategy"] = distribute_lib._DefaultDistributionStrategy()
        else:
          _defaults["strategy"] = (
              distribute_lib._DefaultDistributionStrategyV1())
        distribute_lib._creating_default_strategy_singleton = False
        # pylint: enable=protected-access
  return _defaults["strategy"]


def _get_default_replica_context():
  if _defaults["replica_context"] is None:
    # Avoid a race condition causing two defaults to be created.
    with _default_replica_context_lock:
      if _defaults["replica_context"] is None:
        # pylint: disable=protected-access
        _defaults["replica_context"] = distribute_lib._DefaultReplicaContext(
            _get_default_strategy(), replica_id_in_sync_group=0)
        # pylint: enable=protected-access
  return _defaults["replica_context"]


def _get_default_replica_mode():
  if _defaults["replica_mode"] is None:
    # Avoid a race condition causing two defaults to be created.
    with _default_replica_mode_lock:
      if _defaults["replica_mode"] is None:
        _defaults["replica_mode"] = _DefaultReplicaThreadMode()
  return _defaults["replica_mode"]


# Aliases for compatibility with old names.
get_distribution_strategy = get_strategy
has_distribution_strategy = has_strategy
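

# Illustrative sketch (not part of this module's API), valid outside any
# `strategy.scope()`: the accessors fall back to the lazily created default
# singletons above, so repeated calls return the same objects.
def _example_default_singletons():
  assert not has_strategy()  # only the default strategy is active
  assert get_strategy() is _get_default_strategy()
  assert get_replica_context() is _get_default_replica_context()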