1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility to get tf.distribute.Strategy related contexts.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import threading 23 24from tensorflow.python import tf2 25from tensorflow.python.framework import ops 26from tensorflow.python.util.lazy_loader import LazyLoader 27from tensorflow.python.util.tf_export import tf_export 28 29 30# There is a circular dependency between this and the `distribute_lib` module. 31# So we load it lazily to work around this. 32distribute_lib = LazyLoader( 33 "distribute_lib", globals(), 34 "tensorflow.python.distribute.distribute_lib") 35 36# ------------------------------------------------------------------------------ 37# Internal API for setting the current thread mode as being either in a 38# replica or cross-replica context for a particular tf.distribute.Strategy. 39 40 41class _ThreadMode(object): 42 43 def __init__(self, dist, cross, replica): 44 self.strategy = dist 45 self.cross_replica_context = cross 46 self.replica_context = replica 47 48 49class _CrossReplicaThreadMode(_ThreadMode): 50 51 def __init__(self, strategy): 52 _ThreadMode.__init__(self, strategy, strategy, None) 53 54 55class _InReplicaThreadMode(_ThreadMode): 56 57 def __init__(self, replica_ctx): 58 _ThreadMode.__init__(self, replica_ctx.strategy, None, replica_ctx) 59 60 61def _push_per_thread_mode(context): 62 ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access 63 64 65def _pop_per_thread_mode(): 66 ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access 67 68 69class _DefaultReplicaThreadMode(_ThreadMode): 70 """Type of default value returned by `_get_per_thread_mode()`. 71 72 Used when the thread-local stack is empty. 73 """ 74 75 def __init__(self): 76 _ThreadMode.__init__(self, _get_default_strategy(), None, 77 _get_default_replica_context()) 78 79 80def _get_per_thread_mode(): 81 try: 82 return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access 83 except (AttributeError, IndexError): 84 return _get_default_replica_mode() 85 86 87_variable_sync_on_read_context = threading.local() 88 89 90@tf_export("__internal__.distribute.variable_sync_on_read_context", v1=[]) 91@contextlib.contextmanager 92def variable_sync_on_read_context(): 93 """A context that forces SyncOnReadVariable to aggregate upon reading. 94 95 This context is useful if one wants to read the aggregated value out of a 96 SyncOnReadVariable in replica context. By default the aggregation is turned 97 off per the definition of SyncOnReadVariable. 98 99 When reading a SyncOnReadVariable in cross-replica context, aggregation is 100 always turned on so there is no need for such context. 101 102 By reading a SyncOnReadVariable, we mean: 103 1. Convert the variable to a tensor using `convert_to_tensor`. 104 2. Calling `variable.value()` or `variable.read_value()`. 105 106 Example usage: 107 108 ``` 109 strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 110 with strategy.scope(): 111 v = tf.Variable(1.0, synchronization=tf.VariableSynchronization.ON_READ, 112 aggregation=tf.VariableAggregation.SUM) 113 114 def replica_fn(): 115 return v + 10.0 116 117 non_aggregated = strategy.run(replica_fn) 118 print(non_aggregated) # PerReplica: {0: 11.0, 1: 11.0} 119 120 def replica_fn(): 121 with variable_sync_on_read_context(): 122 return v + 10.0 123 124 aggregated = strategy.run(replica_fn) 125 print(aggregated) # PerReplica: {0: 12.0, 1: 12.0} 126 ``` 127 128 Yields: 129 Context manager for aggregating SyncOnReadVariable upon reading. 130 """ 131 try: 132 _variable_sync_on_read_context.entered = True 133 yield 134 finally: 135 _variable_sync_on_read_context.entered = False 136 137 138def in_variable_sync_on_read_context(): 139 try: 140 return _variable_sync_on_read_context.entered 141 except AttributeError: 142 return False 143 144# ------------------------------------------------------------------------------ 145# Public API for accessing the current thread mode 146 147 148@tf_export("distribute.get_replica_context") 149def get_replica_context(): 150 """Returns the current `tf.distribute.ReplicaContext` or `None`. 151 152 Returns `None` if in a cross-replica context. 153 154 Note that execution: 155 156 1. starts in the default (single-replica) replica context (this function 157 will return the default `ReplicaContext` object); 158 2. switches to cross-replica context (in which case this will return 159 `None`) when entering a `with tf.distribute.Strategy.scope():` block; 160 3. switches to a (non-default) replica context inside `strategy.run(fn, ...)`; 161 4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then 162 inside `merge_fn` you are back in the cross-replica context (and again 163 this function will return `None`). 164 165 Most `tf.distribute.Strategy` methods may only be executed in 166 a cross-replica context, in a replica context you should use the 167 API of the `tf.distribute.ReplicaContext` object returned by this 168 method instead. 169 170 ``` 171 assert tf.distribute.get_replica_context() is not None # default 172 with strategy.scope(): 173 assert tf.distribute.get_replica_context() is None 174 175 def f(): 176 replica_context = tf.distribute.get_replica_context() # for strategy 177 assert replica_context is not None 178 tf.print("Replica id: ", replica_context.replica_id_in_sync_group, 179 " of ", replica_context.num_replicas_in_sync) 180 181 strategy.run(f) 182 ``` 183 184 Returns: 185 The current `tf.distribute.ReplicaContext` object when in a replica context 186 scope, else `None`. 187 188 Within a particular block, exactly one of these two things will be true: 189 190 * `get_replica_context()` returns non-`None`, or 191 * `tf.distribute.is_cross_replica_context()` returns True. 192 """ 193 return _get_per_thread_mode().replica_context 194 195 196def get_cross_replica_context(): 197 """Returns the current tf.distribute.Strategy if in a cross-replica context. 198 199 DEPRECATED: Please use `in_cross_replica_context()` and 200 `get_strategy()` instead. 201 202 Returns: 203 Returns the current `tf.distribute.Strategy` object in a cross-replica 204 context, or `None`. 205 206 Exactly one of `get_replica_context()` and `get_cross_replica_context()` 207 will return `None` in a particular block. 208 """ 209 return _get_per_thread_mode().cross_replica_context 210 211 212@tf_export("distribute.in_cross_replica_context") 213def in_cross_replica_context(): 214 """Returns `True` if in a cross-replica context. 215 216 See `tf.distribute.get_replica_context` for details. 217 218 ``` 219 assert not tf.distribute.in_cross_replica_context() 220 with strategy.scope(): 221 assert tf.distribute.in_cross_replica_context() 222 223 def f(): 224 assert not tf.distribute.in_cross_replica_context() 225 226 strategy.run(f) 227 ``` 228 229 Returns: 230 `True` if in a cross-replica context (`get_replica_context()` returns 231 `None`), or `False` if in a replica context (`get_replica_context()` returns 232 non-`None`). 233 """ 234 return _get_per_thread_mode().cross_replica_context is not None 235 236 237@tf_export("distribute.get_strategy") 238def get_strategy(): 239 """Returns the current `tf.distribute.Strategy` object. 240 241 Typically only used in a cross-replica context: 242 243 ``` 244 if tf.distribute.in_cross_replica_context(): 245 strategy = tf.distribute.get_strategy() 246 ... 247 ``` 248 249 Returns: 250 A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block, 251 it returns `strategy`, otherwise it returns the default (single-replica) 252 `tf.distribute.Strategy` object. 253 """ 254 return _get_per_thread_mode().strategy 255 256 257@tf_export("distribute.has_strategy") 258def has_strategy(): 259 """Return if there is a current non-default `tf.distribute.Strategy`. 260 261 ``` 262 assert not tf.distribute.has_strategy() 263 with strategy.scope(): 264 assert tf.distribute.has_strategy() 265 ``` 266 267 Returns: 268 True if inside a `with strategy.scope():`. 269 """ 270 return get_strategy() is not _get_default_strategy() 271 272 273def get_strategy_and_replica_context(): 274 per_thread_mode = _get_per_thread_mode() 275 return (per_thread_mode.strategy, per_thread_mode.replica_context) 276 277 278@tf_export("distribute.experimental_set_strategy") 279def experimental_set_strategy(strategy): 280 """Set a `tf.distribute.Strategy` as current without `with strategy.scope()`. 281 282 ``` 283 tf.distribute.experimental_set_strategy(strategy1) 284 f() 285 tf.distribute.experimental_set_strategy(strategy2) 286 g() 287 tf.distribute.experimental_set_strategy(None) 288 h() 289 ``` 290 291 is equivalent to: 292 293 ``` 294 with strategy1.scope(): 295 f() 296 with strategy2.scope(): 297 g() 298 h() 299 ``` 300 301 In general, you should use the `with strategy.scope():` API, but this 302 alternative may be convenient in notebooks where you would have to put 303 each cell in a `with strategy.scope():` block. 304 305 Note: This should only be called outside of any TensorFlow scope to 306 avoid improper nesting. 307 308 Args: 309 strategy: A `tf.distribute.Strategy` object or None. 310 311 Raises: 312 RuntimeError: If called inside a `with strategy.scope():`. 313 """ 314 old_scope = ops.get_default_graph()._global_distribute_strategy_scope # pylint: disable=protected-access 315 if old_scope is not None: 316 old_scope.__exit__(None, None, None) 317 ops.get_default_graph()._global_distribute_strategy_scope = None # pylint: disable=protected-access 318 if has_strategy(): 319 raise RuntimeError( 320 "Must not be called inside a `tf.distribute.Strategy` scope.") 321 if strategy is not None: 322 new_scope = strategy.scope() 323 new_scope.__enter__() 324 ops.get_default_graph()._global_distribute_strategy_scope = new_scope # pylint: disable=protected-access 325 326 327# ------------------------------------------------------------------------------ 328# Internal helpers. 329 330 331@contextlib.contextmanager 332def enter_or_assert_strategy(strategy): 333 if has_strategy(): 334 _assert_strategy(strategy) 335 yield 336 else: 337 with strategy.scope(): 338 yield 339 340 341# ------------------------------------------------------------------------------ 342# Defaults that are used when no tf.distribute.Strategy is explicitly created. 343# We create them lazily in a function so that we can workaround the circular 344# dependency on distribute_lib. See lazy loader at the top of this file. 345 346_defaults = { 347 "strategy": None, 348 "replica_context": None, 349 "replica_mode": None 350} 351# Note: These need to be different locks since _get_default_replica_context 352# calls _get_default_strategy inside its lock, and them using the same lock 353# can lead to deadlock. 354_default_strategy_lock = threading.Lock() 355_default_replica_context_lock = threading.Lock() 356_default_replica_mode_lock = threading.Lock() 357 358 359def _assert_strategy(strategy): 360 if not has_strategy(): 361 raise RuntimeError('Need to be inside "with strategy.scope()" for %s' % 362 (strategy,)) 363 current_strategy = get_strategy() 364 if current_strategy is not strategy: 365 raise RuntimeError( 366 "Mixing different tf.distribute.Strategy objects: %s is not %s" % 367 (current_strategy, strategy)) 368 369 370def _get_default_strategy(): 371 if _defaults["strategy"] is None: 372 # Avoid race condition causing two defaults to be created 373 with _default_strategy_lock: 374 if _defaults["strategy"] is None: 375 # pylint: disable=protected-access 376 # Make sure distribute_lib module is loaded by accessing some member. 377 _ = distribute_lib._creating_default_strategy_singleton 378 distribute_lib._creating_default_strategy_singleton = True 379 if tf2.enabled(): 380 _defaults["strategy"] = distribute_lib._DefaultDistributionStrategy() 381 else: 382 _defaults["strategy"] = ( 383 distribute_lib._DefaultDistributionStrategyV1()) 384 distribute_lib._creating_default_strategy_singleton = False 385 # pylint: enable=protected-access 386 return _defaults["strategy"] 387 388 389def _get_default_replica_context(): 390 if _defaults["replica_context"] is None: 391 # Avoid race condition causing two defaults to be created 392 with _default_replica_context_lock: 393 if _defaults["replica_context"] is None: 394 # pylint: disable=protected-access 395 _defaults["replica_context"] = distribute_lib._DefaultReplicaContext( 396 _get_default_strategy(), replica_id_in_sync_group=0) 397 # pylint: enable=protected-access 398 return _defaults["replica_context"] 399 400 401def _get_default_replica_mode(): 402 if _defaults["replica_mode"] is None: 403 # Avoid race condition causing two defaults to be created 404 with _default_replica_mode_lock: 405 if _defaults["replica_mode"] is None: 406 _defaults["replica_mode"] = _DefaultReplicaThreadMode() 407 return _defaults["replica_mode"] 408 409 410# Aliases for compatibility with old names. 411get_distribution_strategy = get_strategy 412has_distribution_strategy = has_strategy 413